Перейти к содержанию

Vision Transformers (ViT): от патчей к SOTA

~8 минут чтения

Предварительно: Реализация внимания с нуля | Позиционное кодирование

Vision Transformer разрезает изображение 224x224 на 196 патчей (16x16 пикселей каждый), проецирует их в embedding-пространство и обрабатывает self-attention -- точно как слова в тексте. На ImageNet-1K (1.3M изображений) ViT проигрывает CNN аналогичного размера, но при pretraining на JFT-300M масштабируется лучше. DeiT (2021) сделал ViT практичным без огромных данных: knowledge distillation от CNN-учителя дает 83.1% top-1 на ImageNet. Swin Transformer добавил hierarchical window attention (\(O(N)\) вместо \(O(N^2)\)), став стандартным backbone для detection и segmentation. Сегодня ViT -- обязательный vision encoder в multimodal моделях (CLIP, GPT-4V, LLaVA, Gemini).


Главная идея

Vision Transformer (ViT) -- применение архитектуры трансформера к изображениям. Вместо пикселей модель работает с патчами (кусочками изображения), которые превращаются в последовательность токенов -- точно как слова в тексте.

Аналогия: представьте, что вы разрезали фотографию на 196 квадратиков (14x14 сетка) и выложили их в ряд. Каждый квадратик -- это "слово", а self-attention определяет, какие квадратики важны для понимания всей картинки.

Aha-момент: CNN обрабатывают изображение иерархически -- сначала локальные паттерны (edges), потом глобальные (объекты). ViT с первого слоя видит ВСЮ картинку через self-attention. Это и преимущество (глобальный контекст), и недостаток (нужно больше данных, чтобы научиться тому, что CNN получают бесплатно через inductive bias).


Patch Embedding: как изображение становится последовательностью

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # 196 для 224/16
        # Свертка с kernel=stride=patch_size -- нарезает и проецирует за один шаг
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]
        # [B, 3, 224, 224] -> [B, 768, 14, 14] -> [B, 196, 768]
        x = self.proj(x).flatten(2).transpose(1, 2)
        # Добавляем [CLS] token
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)  # [B, 197, 768]
        # Добавляем positional embedding
        x = x + self.pos_embed
        return x

Размерность Patch Embedding

Изображение \(H \times W \times C\) с patch size \(P\):

\[N = \frac{H \times W}{P^2}\]

Каждый патч: \(P \times P \times C\) пикселей \(\rightarrow\) линейная проекция \(\rightarrow\) вектор размерности \(D\).

Пример: $224 \times 224 \times 3$, $P=16$: $N = 224^2 / 16^2 = 196$ патчей, каждый $16 \times 16 \times 3 = 768$ значений $\rightarrow$ проекция в $D=768$.

Эволюция Vision Transformers

ViT (2020, Google) -- доказательство концепции

  • Чистый трансформер, никаких сверток (кроме patch projection)
  • Проблема: нужен JFT-300M (300M изображений) для хороших результатов
  • На ImageNet-1K (1.3M изображений) проигрывает CNN аналогичного размера

DeiT (2021, Meta) -- сделал ViT практичным

Ключевое нововведение: knowledge distillation от CNN-учителя (RegNet).

  • Добавил distillation token рядом с [CLS] -- учится имитировать предсказания CNN
  • Результат: ViT-уровень качества используя ТОЛЬКО ImageNet-1K
  • DeiT-B достигает 83.1% top-1 на ImageNet без внешних данных

Swin Transformer (2021, Microsoft) -- иерархическая структура

graph LR
    IMG["Image<br/>224x224"] --> S1["Stage 1<br/>56x56, C=96"]
    S1 --> S2["Stage 2<br/>28x28, C=192"]
    S2 --> S3["Stage 3<br/>14x14, C=384"]
    S3 --> S4["Stage 4<br/>7x7, C=768"]

    style IMG fill:#e8eaf6,stroke:#3f51b5
    style S1 fill:#e8f5e9,stroke:#4caf50
    style S2 fill:#e8f5e9,stroke:#4caf50
    style S3 fill:#fff3e0,stroke:#ef6c00
    style S4 fill:#fce4ec,stroke:#c62828
  • Shifted window attention вместо global: O(N) вместо O(N^2)
  • Hierarchical: увеличивает channels, уменьшает resolution (как CNN)
  • Стал backbone для detection (Mask R-CNN + Swin) и segmentation

DINO / DINOv2 (2021/2023, Meta) -- self-supervised ViT

  • Self-distillation: student и teacher -- обе ViT, teacher -- EMA student
  • Без меток учит features, которые уже содержат семантику объектов
  • Attention maps DINO визуально сегментируют объекты -- без обучения на segmentation
  • DINOv2 (2023) -- улучшенная версия, foundation model для vision features

SAM (2023, Meta) -- Segment Anything

  • Promptable segmentation: клик, bbox, текст \(\rightarrow\) маска объекта
  • ViT-H encoder + lightweight mask decoder
  • Обучен на SA-1B (1 billion масок, 11M изображений)
  • Zero-shot generalization -- работает на unseen domains без fine-tuning

ViT vs CNN: когда что выбирать

Критерий CNN (ResNet, EfficientNet) ViT / Swin Когда что
Малые данные (< 10K) Лучше (inductive bias) Хуже CNN + аугментация
Большие данные (> 1M) Упирается в потолок Масштабируется лучше ViT + pretraining
Скорость inference Быстрее (conv оптимизирован) Медленнее (attention O(N^2)) CNN для edge/mobile
Transfer learning ImageNet pretrained ViT pretrained (DINOv2) ViT, если есть хороший pretrained
Detection/Segmentation ResNet + FPN Swin + FPN Swin на 2-3% лучше
Multimodal Не подходит Стандарт (CLIP, GPT-4V) Только ViT

ViT НЕ всегда лучше CNN

При < 100K изображений и без pretrained модели CNN (ResNet-50, EfficientNet) часто побеждает ViT. Inductive bias CNN (locality, translation equivariance) -- это бесплатное знание о структуре изображений. ViT должен выучить это из данных. При достаточном количестве данных ViT выигрывает, но "достаточно" -- это миллионы примеров или хороший pretrained backbone.


Multimodal: ViT как vision encoder

Современные multimodal модели используют ViT как зрительный компонент:

Модель Vision Encoder Как связан с LLM
CLIP (OpenAI, 2021) ViT-L/14 Contrastive learning: image embedding ~ text embedding
GPT-4V/o (OpenAI) Undisclosed ViT Image tokens подаются в LLM напрямую
LLaVA (2023) CLIP ViT-L + MLP projection Visual tokens \(\rightarrow\) проекция \(\rightarrow\) LLM input
Gemini (Google) Undisclosed, likely ViT Natively multimodal
Claude 3.5 (Anthropic) Undisclosed Vision + language interleaved

Паттерн: ViT кодирует изображение в последовательность embedding'ов \(\rightarrow\) проекция в пространство LLM \(\rightarrow\) LLM обрабатывает visual tokens наравне с текстовыми.


  • Scaling ViT: ViT-22B (Google, 2023) -- 22 миллиарда параметров, 4B image-text pairs
  • Efficient ViT: EfficientViT, FastViT -- оптимизация для mobile/edge
  • MAE (Masked Autoencoders): self-supervised pretraining -- маскируем 75% патчей, восстанавливаем
  • FlexiViT: ViT с переменным patch size -- одна модель для разных разрешений
  • SigLIP: замена CLIP с sigmoid loss (без softmax) -- лучше масштабируется

Типичные ошибки

Путать ViT и Swin

ViT использует global self-attention (каждый патч видит все остальные). Swin -- window attention (патч видит только соседей в окне 7x7). У них разная вычислительная сложность: ViT O(N^2), Swin O(N). Для detection/segmentation Swin предпочтительнее.

Fine-tuning ViT с нуля на малых данных

ViT без pretrained весов на 10K изображений -- гарантированный overfitting. Используйте pretrained backbone (DINOv2, CLIP) и fine-tune только классификационную голову или последние слои. Это не "хак", это стандартная практика.

Заблуждение: patch size не влияет на качество

Patch size \(P\) определяет trade-off между детализацией и compute. При \(P=16\) изображение 224x224 дает 196 токенов. При \(P=8\) -- 784 токена (4x больше), а attention стоит \(O(N^2)\) -- в 16x дороже. Но мелкие патчи критичны для задач, где важны детали (медицинские изображения, OCR, мелкие объекты). ViT-22B и FlexiViT решают это: одна модель работает с разными patch sizes. Выбор \(P\) -- осознанное архитектурное решение, а не default.


Interview Questions

1. Как ViT превращает изображение в последовательность токенов?

❌ Red flag: "ViT подает пиксели напрямую в трансформер"

✅ Strong answer: "Изображение \(H \times W\) нарезается на патчи \(P \times P\). Каждый патч (например, \(16 \times 16 \times 3 = 768\) значений) проецируется через линейный слой (Conv2d с kernel=stride=P) в embedding размерности \(D\). Добавляется [CLS] token для классификации и learned positional embeddings. 224x224 с P=16 = 196 патчей + 1 CLS = 197 токенов. Далее -- стандартный трансформер-энкодер."

2. Почему ViT требует больше данных чем CNN для тренировки с нуля?

❌ Red flag: "ViT просто больше по параметрам"

✅ Strong answer: "CNN имеет inductive bias: locality (свертки смотрят на соседние пиксели) и translation equivariance (паттерн распознается в любом месте изображения). Это бесплатное знание о структуре изображений. ViT с global self-attention должен выучить эти свойства из данных. На ImageNet-1K (1.3M) ViT проигрывает CNN, но при pretraining на JFT-300M или LAION-5B масштабируется лучше. DeiT решает это через distillation от CNN-учителя: 83.1% top-1 без внешних данных."

3. Объясните shifted window attention в Swin Transformer.

❌ Red flag: "Swin просто считает attention в маленьких окнах"

✅ Strong answer: "В нечетных слоях attention считается внутри фиксированных окон (7x7 патчей) -- \(O(N)\) вместо \(O(N^2)\). Но без cross-window связи информация не перетекает между окнами. Решение: в четных слоях окна сдвигаются на половину размера -- каждый патч теперь взаимодействует с соседями из других окон предыдущего слоя. Плюс hierarchical structure: 4 стадии с увеличением channels и уменьшением resolution (как в CNN). Это делает Swin стандартным backbone для detection и segmentation."

4. Вы проектируете vision pipeline для мультимодальной модели. Как выберете vision encoder?

❌ Red flag: "Возьмем ResNet -- он хорошо изучен"

✅ Strong answer: "Pretrained CLIP ViT-L/14 или DINOv2 как backbone -- оба дают rich semantic features без supervised labels. Projection layer (MLP) для маппинга visual embeddings в пространство LLM. Разрешение: больше патчей = лучше детализация, но квадратично дороже по compute. Для real-time: EfficientViT или SigLIP (sigmoid loss вместо softmax, лучше масштабируется). LLaVA-паттерн: CLIP ViT + 2-layer MLP + LLM. CNN не подходит: нет sequence output для подачи в LLM."

Самопроверка

У вас датасет из 50K медицинских рентгеновских снимков. Нужно классифицировать патологии (15 классов). Какую архитектуру выберете и почему?

Подсказки: (1) данных немного для ViT с нуля, (2) есть ли pretrained на медицинских данных?, (3) какой inductive bias полезен для рентгенов?


See Also

Sources

  1. An Image is Worth 16x16 Words (Dosovitskiy et al., ICLR 2021) -- оригинальный ViT
  2. DeiT (Touvron et al., 2021) -- data-efficient training
  3. Swin Transformer (Liu et al., ICCV 2021) -- hierarchical ViT
  4. DINO (Caron et al., 2021) -- self-supervised ViT
  5. DINOv2 (Oquab et al., 2023) -- foundation vision features
  6. SAM (Kirillov et al., 2023) -- Segment Anything
  7. Scaling Vision Transformers (Dehghani et al., 2023) -- ViT-22B