
Decoding: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale


Vision Transformer (ViT) – High-level Take-aways

  • Main problem addressed
    Convolutional Neural Networks (CNNs) dominate vision, yet they embed hand-crafted inductive biases (locality, translation equivariance) that may limit scalability. The paper asks: Can a standard NLP Transformer, with minimal changes, match or surpass state-of-the-art CNNs on image classification when trained at scale?

  • Core idea

    1. Treat an image as a sequence of non-overlapping patches $P \times P$ (e.g., 16×16) and linearly project each flattened patch to a D-dimensional token embedding (Eq. 1, p. 3).

    2. Prepend a learnable “[class]” token, add 1-D positional embeddings, and feed the resulting sequence into a standard Transformer encoder that follows the original Transformer (Vaswani et al., 2017) as closely as possible, with LayerNorm applied before each sub-block (Fig. 1, p. 2).

    3. Supervised pre-training on very large image corpora (ImageNet-21k, 14 M images; JFT-300M, 303 M images) compensates for the lack of CNN inductive bias.

    4. After fine-tuning (often at a higher resolution than pre-training), the resulting Vision Transformer matches or exceeds strong CNN baselines while needing 2–4× less pre-training compute to reach the same performance (Table 2, p. 5; Fig. 5, p. 7).

  • Claimed contributions

    • Introduce Vision Transformer (pure Transformer on image patches)

    • Show that large-scale pre-training outweighs the missing built-in inductive biases for vision

    • Achieve 88.55 % ImageNet top-1 with ViT-H/14 (state-of-the-art at submission)

    • Demonstrate strong transfer: 94.55 % CIFAR-100, 77.63 % VTAB mean

    • Provide compute-efficient training (at least 2× fewer TPU core-days than the prior SOTA)

    • Release code & pre-trained checkpoints for reproducibility
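Point 1 of the core idea (cut the image into P×P patches, flatten, project to D dimensions) can be sketched in NumPy as below. This is a minimal illustration, not the authors' code: it assumes a channels-last H×W×C array, and the projection matrix `E` and bias `b` are hypothetical random stand-ins for the learned weights. Many implementations compute the same projection as a convolution whose kernel size and stride both equal P.

```python
import numpy as np

def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an H x W x C image into N = (H/P) * (W/P) flattened patches.

    Returns shape (N, P*P*C); patches are ordered row by row (raster order).
    """
    h, w, c = image.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image sides must be divisible by P"
    # (H/P, P, W/P, P, C) -> (H/P, W/P, P, P, C) -> (N, P*P*C)
    x = image.reshape(h // p, p, w // p, p, c)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, p * p * c)

def embed_patches(patches: np.ndarray, E: np.ndarray, b: np.ndarray) -> np.ndarray:
    # One shared linear map for every patch: (N, P*P*C) @ (P*P*C, D) + (D,)
    return patches @ E + b

rng = np.random.default_rng(0)
image = rng.standard_normal((224, 224, 3))     # stand-in for a 224x224 RGB image
patches = patchify(image, 16)                  # (196, 768): 14 x 14 patches of 16*16*3 values
D = 768
E = rng.standard_normal((16 * 16 * 3, D)) * 0.02   # hypothetical projection weights
b = np.zeros(D)
tokens = embed_patches(patches, E, b)          # (196, 768) patch embeddings
```

With P = 16 a 224×224 image yields N = 224²/16² = 196 tokens, which is where the paper's title ("16x16 words") comes from.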

Architecture Details

High-Level Block Diagram

  • Linear Projection of Flattened Patches – cuts the image into non-overlapping P×P patches, flattens each, and maps it to a D-dimensional vector.

  • [class] Embedding – a learned vector prepended to the patch sequence whose final state will carry the whole-image representation.

  • Patch + Position Embedding – adds learnable 1-D positional embeddings to every token (patch or class).

  • Transformer Encoder (stack repeated L times) – each layer applies LayerNorm followed by Multi-Head Self-Attention (MSA), then LayerNorm followed by an MLP, with a residual connection around each of the two sub-blocks.

  • MLP Head – a task-specific classifier that maps the final [class] representation to logits over K classes: an MLP with one hidden layer during pre-training, replaced by a single linear layer for fine-tuning.
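The sizes behind the diagram (depth L, width D, MLP width, number of heads, patch size P) fit in a small config object. The sketch below is hypothetical (`ViTConfig` is not taken from the paper's code); the Base/Large/Huge values are the ones listed in Table 1 of the paper, and a name such as "ViT-L/16" means the Large model with 16×16 patches.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ViTConfig:
    layers: int   # L: number of encoder blocks
    hidden: int   # D: token width
    mlp: int      # D_MLP: hidden width inside each MLP block
    heads: int    # number of self-attention heads
    patch: int    # P: patch side length in pixels

    @property
    def head_dim(self) -> int:
        # Width of each attention head; D is split evenly across heads
        return self.hidden // self.heads

    def seq_len(self, image_size: int) -> int:
        # N = (image_size / P)^2 patch tokens, plus the prepended [class] token
        assert image_size % self.patch == 0, "image size must be divisible by P"
        return (image_size // self.patch) ** 2 + 1

# Layer / width / MLP / head counts as listed in Table 1 of the paper
VIT_B16 = ViTConfig(layers=12, hidden=768, mlp=3072, heads=12, patch=16)
VIT_L16 = ViTConfig(layers=24, hidden=1024, mlp=4096, heads=16, patch=16)
VIT_H14 = ViTConfig(layers=32, hidden=1280, mlp=5120, heads=16, patch=14)
```

For example, `VIT_B16.seq_len(224)` gives 197 tokens and `VIT_H14.seq_len(224)` gives 257. Because the patch size stays fixed, fine-tuning at a higher resolution lengthens the sequence (577 tokens for ViT-B/16 at 384×384), which is why the paper 2-D interpolates the pre-trained position embeddings.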

Detailed Data Flow

  1. Input Image → Linear Projection of Flattened Patches

      • Operation: Slice the image into N = HW/P² non-overlapping P×P patches, flatten each into a vector of length P²·C, and apply one shared linear map.
      • Inputs: H×W×C image.
      • Outputs: N patch embeddings $\in \mathbb{R}^{N \times D}$.

  2. [class] Embedding → Sequence Concatenation

      • Operation: Prepend a learned [class] vector to the patch embeddings.
      • Inputs: [class] token (1×D), patch embeddings (N×D).
      • Outputs: (N+1)-token sequence $\in \mathbb{R}^{(N+1) \times D}$.

  3. Position Embedding Addition → Patch + Position Embedding

      • Operation: Add learnable 1-D positional vectors elementwise to each token.
      • Inputs: token sequence, positional table ((N+1)×D).
      • Outputs: position-encoded sequence $z_0 \in \mathbb{R}^{(N+1) \times D}$.

  4. LayerNorm → Multi-Head Self-Attention (Layer 1)

      • Operation: Apply LayerNorm, then let each token attend to all tokens, producing context-mixed representations.
      • Inputs: $z_0$.
      • Outputs: attention output (same shape).

  5. Residual Add (Layer 1)

      • Operation: Add the layer input back to the attention output (first skip connection).
      • Inputs: $z_0$ and the attention output.
      • Outputs: intermediate sequence $z'_1$.

  6. LayerNorm → MLP Block (Layer 1)

      • Operation: Apply LayerNorm, then two dense layers with a GELU in between, projecting D → D_MLP → D.
      • Inputs: $z'_1$.
      • Outputs: MLP output (same shape).

  7. Residual Add (Layer 1)

      • Operation: Add $z'_1$ to the MLP output (second skip connection).
      • Inputs: $z'_1$ and the MLP output.
      • Outputs: $z_1$, the input to the next layer.

  8. Steps 4–7 repeat L − 1 more times in the Transformer Encoder

      • Operation: Deeper context mixing through identically structured (but separately parameterized) layers.
      • Inputs: sequence from the previous layer.
      • Outputs: final encoded sequence $z_L \in \mathbb{R}^{(N+1) \times D}$.

  9. Extract [class] Token → LayerNorm

      • Operation: Apply a final LayerNorm to the [class] position of $z_L$.
      • Inputs: first token of $z_L$.
      • Outputs: image representation $y \in \mathbb{R}^{D}$.

  10. Representation y → MLP Head

      • Operation: Single linear layer (during fine-tuning) mapping D → K logits.
      • Inputs: $y$.
      • Outputs: class-score vector (logits) $\in \mathbb{R}^{K}$.

  11. Logits → Softmax (not drawn in the diagram)

      • Operation: Convert logits to class probabilities.
      • Inputs: logits.
      • Outputs: probability distribution over K classes.
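The whole data flow, from patch embeddings to class probabilities, can be traced in a toy NumPy forward pass. This is a minimal sketch with hypothetical tiny sizes and random weights, not a faithful reimplementation: it assumes the tanh approximation of GELU, omits the biases of the attention projections, omits dropout, and processes one image without a batch axis. Each encoder layer is pre-LN, i.e. LayerNorm comes before attention and before the MLP, with a residual add after each.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    # Normalize each token over its D features, then scale and shift
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * gamma + beta

def gelu(x):
    # tanh approximation of GELU (assumption; the exact form uses erf)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def msa(x, p, heads):
    """Multi-head self-attention over a (T, D) sequence; projection biases omitted."""
    t, d = x.shape
    dh = d // heads
    def split(w):  # project, then reshape (T, D) -> (heads, T, dh)
        return (x @ w).reshape(t, heads, dh).transpose(1, 0, 2)
    q, k, v = split(p["wq"]), split(p["wk"]), split(p["wv"])
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(dh))  # (heads, T, T)
    out = (attn @ v).transpose(1, 0, 2).reshape(t, d)       # concatenate heads
    return out @ p["wo"]

def encoder_layer(z, p, heads):
    # Attention sub-block: z' = MSA(LN(z)) + z
    z = msa(layer_norm(z, p["ln1_g"], p["ln1_b"]), p, heads) + z
    # MLP sub-block: D -> D_MLP -> D with GELU, then residual add
    h = gelu(layer_norm(z, p["ln2_g"], p["ln2_b"]) @ p["w1"] + p["b1"])
    return h @ p["w2"] + p["b2"] + z

def init_layer(rng, d, d_mlp, scale=0.02):
    def w(*shape):
        return rng.standard_normal(shape) * scale
    return {
        "wq": w(d, d), "wk": w(d, d), "wv": w(d, d), "wo": w(d, d),
        "ln1_g": np.ones(d), "ln1_b": np.zeros(d),
        "ln2_g": np.ones(d), "ln2_b": np.zeros(d),
        "w1": w(d, d_mlp), "b1": np.zeros(d_mlp),
        "w2": w(d_mlp, d), "b2": np.zeros(d),
    }

# Hypothetical toy sizes: N patches, width D, MLP width, heads, depth L, K classes
rng = np.random.default_rng(0)
N, D, D_MLP, HEADS, L, K = 4, 8, 16, 2, 3, 5
patch_tokens = rng.standard_normal((N, D))      # stand-in for projected patches
cls = rng.standard_normal((1, D)) * 0.02        # learned [class] vector (random here)
pos = rng.standard_normal((N + 1, D)) * 0.02    # learned 1-D position table (random here)

z = np.concatenate([cls, patch_tokens], axis=0) + pos  # prepend [class], add positions
layers = [init_layer(rng, D, D_MLP) for _ in range(L)]
for p in layers:
    z = encoder_layer(z, p, HEADS)              # L pre-LN encoder layers
y = layer_norm(z[0], np.ones(D), np.zeros(D))   # final LayerNorm on the [class] slot
W_head = rng.standard_normal((D, K)) * 0.02     # linear (fine-tuning style) head
logits = y @ W_head
probs = softmax(logits)
```

A useful property of the pre-LN layout is visible here: if every weight in a layer is zero, that layer returns its input unchanged, because both sub-blocks contribute nothing and only the residual path remains.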

The Extra [class] Token in Vision Transformer

What exactly is fed in as the [class] token?

  • It is not derived from the image.

  • It is a single, learned D-dimensional vector (the same width as any patch embedding) that is trained together with the rest of the model parameters.

  • During every forward pass, the same token is prepended to each image's patch-embedding sequence, so the input length becomes N + 1.
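In batched code, "prepending" means copying one shared vector to position 0 of every sequence. The sketch below assumes a NumPy batch of shape (B, N, D); `prepend_class_token` is a hypothetical helper name, and the zero vector used for `cls` is only an illustrative starting value (in training, this vector is updated by back-propagation).

```python
import numpy as np

def prepend_class_token(patch_emb: np.ndarray, cls: np.ndarray) -> np.ndarray:
    """patch_emb: (B, N, D) patch embeddings; cls: (D,) learned [class] vector.

    Returns (B, N + 1, D) with the same cls vector at index 0 of every sequence.
    """
    b, _, d = patch_emb.shape
    cls_rows = np.broadcast_to(cls, (b, 1, d))  # no image data involved
    return np.concatenate([cls_rows, patch_emb], axis=1)

rng = np.random.default_rng(1)
B, N, D = 2, 196, 768
patch_emb = rng.standard_normal((B, N, D))  # stand-in for two images' patch embeddings
cls = np.zeros(D)                           # hypothetical initial value
seq = prepend_class_token(patch_emb, cls)   # (2, 197, 768)
```

Both images start with an identical first token; any difference in its final state must come from attending to that image's patches.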

Why add it at all?

  • Global pooling substitute: Transformers output one vector per token, so one must be chosen to represent the whole image. The [class] token gives the model a designated slot whose final hidden state becomes that representation.

  • Information sink: because this token attends to, and is attended by, all patch tokens in every self-attention layer, it can collect a summary of the entire image content.

  • Compatibility with NLP practice: BERT uses the same mechanism for sentence-level tasks; ViT adopts that convention to stay as close as possible to the standard Transformer setup.
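The "information sink" point can be made concrete with a single attention head: the [class] token's row of the attention matrix is a softmax over all N + 1 tokens, so every patch receives a positive weight and the weights sum to 1. The sketch below uses random stand-in tokens and hypothetical query/key matrices `wq` and `wk`; it shows the mechanism only, not trained ViT behavior.

```python
import numpy as np

def class_token_attention(z: np.ndarray, wq: np.ndarray, wk: np.ndarray) -> np.ndarray:
    """Single-head attention weights from the [class] token (row 0) to every token."""
    d = z.shape[1]
    scores = (z[0] @ wq) @ (z @ wk).T / np.sqrt(d)  # (N + 1,) scaled dot products
    e = np.exp(scores - scores.max())                # numerically stable softmax
    return e / e.sum()

rng = np.random.default_rng(0)
N, D = 196, 64
z = rng.standard_normal((N + 1, D))            # stand-in token sequence, [class] first
wq = rng.standard_normal((D, D)) * 0.05        # hypothetical query projection
wk = rng.standard_normal((D, D)) * 0.05        # hypothetical key projection
weights = class_token_attention(z, wq, wk)     # one weight per token, including itself
```

The [class] token's updated state is this weighted mix of value vectors; stacking L such layers lets it refine the summary repeatedly.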

How does the token “learn”?

  • Parameter learning

    • Its initial embedding vector is optimized like any other weight through back-propagation.

    • With L layers of self-attention, the token is repeatedly updated by mixing with patch tokens. Because the classification loss is computed from its final state, gradients flow back through every layer to its initial embedding.

  • Representation learning

    • Early layers let the token gather coarse image context.

    • Deeper layers refine it into a discriminative vector.

    • The final LayerNorm output, $y = \mathrm{LN}(z_L^0)$ (the [class] position of the last layer), is what the MLP head reads.

  • No special loss term: learning is driven entirely by the downstream task loss (e.g., cross-entropy for classification) applied to the head’s logits.
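To see that an ordinary task loss is enough to train the [class] vector, consider a toy stand-in for ViT: one single-head attention layer with a residual, a linear head on the [class] row, and plain cross-entropy. In the sketch below, only the [class] vector is updated and the gradient is estimated with central finite differences instead of back-propagation; both are simplifications for illustration, and all weights are random hypothetical values.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def toy_loss(cls, patches, wq, wk, wv, w_head, label):
    """Toy model: single-head attention + residual, read the [class] row, cross-entropy."""
    z = np.vstack([cls, patches])                  # (N + 1, D), [class] first
    q, k, v = z @ wq, z @ wk, z @ wv
    attn = softmax(q @ k.T / np.sqrt(z.shape[1]))  # (N + 1, N + 1)
    y = (attn @ v + z)[0]                          # keep only the [class] row
    logits = y @ w_head
    m = logits.max()
    log_probs = logits - (m + np.log(np.exp(logits - m).sum()))
    return -log_probs[label]                       # no extra loss term

def numerical_grad(f, x, eps=1e-5):
    # Central differences, standing in for back-propagation
    g = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step.flat[i] = eps
        g.flat[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
N, D, K = 6, 8, 3
patches = rng.standard_normal((N, D))
wq, wk, wv = (rng.standard_normal((D, D)) * 0.1 for _ in range(3))
w_head = rng.standard_normal((D, K)) * 0.5
label = 1

def loss_of_cls(c):
    return toy_loss(c, patches, wq, wk, wv, w_head, label)

cls = np.zeros(D)                  # hypothetical starting value
loss_before = loss_of_cls(cls)
for _ in range(50):
    cls = cls - 0.05 * numerical_grad(loss_of_cls, cls)  # gradient step on [class] only
loss_after = loss_of_cls(cls)
```

In a real model the same gradient also reaches the attention, MLP, and head weights; freezing them here isolates the [class] vector's own update.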

Couldn’t we just use global average pooling instead?

Yes, and the authors tried it. A head on globally average-pooled patch tokens performs about as well, provided its learning rate is tuned separately; the [class] token was kept because it stays closer to the BERT/Transformer design and needs no extra pooling step. It also lets the model learn how much weight to give each patch instead of averaging them uniformly.
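The two readouts differ only in how the final sequence is reduced to one vector. A minimal sketch, assuming random stand-ins for encoder outputs and omitting the final LayerNorm: the [class]-token model reads index 0, while the GAP variant (which has no [class] token) averages its N patch tokens. The shuffle at the end shows that GAP weights every patch equally, regardless of position.

```python
import numpy as np

def class_token_readout(z_L: np.ndarray) -> np.ndarray:
    # z_L: (N + 1, D) final encoder output with the [class] token at index 0
    return z_L[0]

def gap_readout(patch_tokens: np.ndarray) -> np.ndarray:
    # patch_tokens: (N, D) final patch representations of a model without a [class] token
    return patch_tokens.mean(axis=0)

rng = np.random.default_rng(0)
N, D = 196, 768
z_L = rng.standard_normal((N + 1, D))     # stand-in output of a [class]-token model
patches_L = rng.standard_normal((N, D))   # stand-in output of a GAP model

y_token = class_token_readout(z_L)
y_gap = gap_readout(patches_L)
# Shuffling the patches leaves the GAP vector unchanged: each patch has weight 1/N
y_gap_shuffled = gap_readout(patches_L[rng.permutation(N)])
```

Either vector then feeds the same kind of linear head; the fixed 1/N weights of GAP are exactly what the [class] token replaces with learned, input-dependent attention weights.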

Putting it Together – Algorithm Flow

  1. Patchify & embed: apply Eq. (1) to build the token sequence, including the [class] token and positional embeddings.

  2. Repeat for ℓ = 1…L:
     2.1 Self-attention with a residual connection, Eq. (2).
     2.2 MLP with a residual connection, Eq. (3).

  3. Readout: apply LayerNorm to the final [class] token, Eq. (4).

  4. Head: pass $y$ through a linear or MLP layer to predict class logits.
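For reference, the four equations the steps above cite (Sec. 3.1 of the paper) are, with $\mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D}$ the patch projection, $\mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D}$ the position table, and $N = HW/P^2$:

$$
\begin{aligned}
z_0 &= [\,x_{\text{class}};\ x_p^1 \mathbf{E};\ x_p^2 \mathbf{E};\ \cdots;\ x_p^N \mathbf{E}\,] + \mathbf{E}_{pos} && (1)\\
z'_\ell &= \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}, \quad \ell = 1 \dots L && (2)\\
z_\ell &= \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell, \quad \ell = 1 \dots L && (3)\\
y &= \mathrm{LN}(z_L^0) && (4)
\end{aligned}
$$

LayerNorm sits inside each residual branch in Eqs. (2) and (3), which is the pre-LN arrangement, and $z_L^0$ is position 0 (the [class] token) of the last layer's output.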

References

  1. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.

  2. Reading with AI Tool