ViT - The Vision Transformer

Mario Parreño · Tags: computer vision, transformer, paper, encoder

Self-attention-based architectures, in particular Transformers, have become the model of choice in natural language processing (NLP). The predominant approach is to pre-train on a large text corpus and then fine-tune on a smaller task-specific dataset.

ViT experiments with applying a standard Transformer directly to images, with the fewest possible modifications. To do so, ViT splits an image into patches and provides the sequence of linear embeddings of these patches as input to a Transformer. Image patches are treated the same way as tokens (words) in NLP applications. The ViT model is trained on image classification in a supervised fashion.

When trained on mid-sized datasets, ViT yields modest accuracies: Transformers lack some of the inductive biases inherent to CNNs, such as translation equivariance and locality, and therefore do not generalize well when trained on insufficient amounts of data. However, the picture changes if ViT is trained on larger datasets: large-scale training trumps inductive bias, and ViT achieves excellent results.

Architecture

ViT’s model architecture is a multi-layer bidirectional Transformer Encoder, following the original Transformer design as closely as possible. If you want more details about the Transformer architecture, you can check out my Transformer blog post.

ViT denotes the number of Transformer encoder blocks as L, the hidden size as H, and the number of self-attention heads as A. ViT’s initial model designs are the following:

ViT model configurations.

Model Name | L (Transformer blocks) | H (Hidden size) | A (Self-Attention heads)
ViT-Base   | 12                     | 768             | 12
ViT-Large  | 24                     | 1024            | 16
ViT-Huge   | 32                     | 1280            | 16
ViT model architecture. The model receives an image as input and outputs the result of its classification. The vector corresponding to the `[CLS]` token is used as the aggregate sequence representation of the image for classification.
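As a rough sketch, the table above can be captured in a small configuration object that the later code snippets can consume. The class and field names (ViTConfig, n_layer, n_embd, n_head, img_size, patch_size, in_channels, num_classes) are my own choices for illustration, not an official API; the defaults mirror ViT-Base with 224×224 images and 16×16 patches.

from dataclasses import dataclass


@dataclass
class ViTConfig:
    # Transformer size (ViT-Base row of the table above)
    n_layer: int = 12        # L: number of Transformer encoder blocks
    n_embd: int = 768        # H: hidden size
    n_head: int = 12         # A: number of self-attention heads
    # Input and head hyperparameters (illustrative defaults)
    img_size: int = 224      # input resolution (assumed square)
    patch_size: int = 16     # patch resolution (assumed square)
    in_channels: int = 3     # RGB images
    num_classes: int = 1000  # e.g. ImageNet-1k classification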

The Input: Patch Embedding

The standard Transformer receives as input a 1D sequence of token embeddings. Naive application of self-attention to images would require that each pixel attends to every other pixel. With quadratic cost in the number of pixels, this does not scale to realistic input sizes.

To handle 2D images, we reshape each image into a sequence of flattened 2D patches of a fixed size. The resulting number of patches serves as the effective input sequence length for the Transformer. The Transformer uses a constant latent vector size H through all of its layers, so the patches are mapped to H dimensions with a trainable linear projection. ViT refers to the output of this projection as the patch embeddings.

A simple way to perform the patch division and projection is through a 2D convolution. The kernel will be of the desired patch size, with a stride of the patch size as well, so that patches do not overlap. To obtain H dimensions, we just need to set H (also known as n_embd) as the number of output channels of the 2D convolution operation.

import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.projection = nn.Conv2d(
            in_channels=config.in_channels, out_channels=config.n_embd,
            kernel_size=config.patch_size, stride=config.patch_size
        )
        self.n_patches = (config.img_size // config.patch_size) ** 2

    def forward(self, x):  # (batch, in_channels, height, width)
        x = self.projection(x)  # (batch, n_embd, img_size // patch_size, img_size // patch_size)
        x = x.flatten(2)  # (batch, n_embd, n_patches) - "stacking" all patches
        x = x.transpose(1, 2)  # (batch, n_patches, n_embd)
        return x
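A quick shape check, using the hypothetical ViTConfig sketched above (the input is random, only the shapes matter here):

import torch

config = ViTConfig()                  # ViT-Base-like: 224x224 images, 16x16 patches
patch_embed = PatchEmbedding(config)

images = torch.randn(2, 3, 224, 224)  # (batch, in_channels, height, width)
patches = patch_embed(images)
print(patches.shape)                  # torch.Size([2, 196, 768]), 196 = (224 // 16) ** 2
print(patch_embed.n_patches)          # 196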

Adapt to Higher Resolutions

It is often beneficial to fine-tune at a higher resolution than the one used during pre-training. When feeding images of higher resolution, the authors keep the patch size the same, which results in a larger effective sequence length.

ViT can handle arbitrary sequence lengths (up to memory constraints); however, the pre-trained positional embeddings may no longer be meaningful. We therefore perform 2D interpolation of the pre-trained positional embeddings, according to their location in the original image. Note that this resolution adjustment and patch extraction are the only points at which an inductive bias about the 2D structure of the images is manually injected into the model.

import math

import torch
import torch.nn.functional as F


# example: 224:(14x14+1) -> 384: (24x24+1) (patch size 16x16)
def resize_pos_embed(posemb, ntok_new):
    """
      Resize the grid of position embeddings to accommodate a new number of tokens.

      Args:
      - posemb (torch.Tensor): Position embeddings (batch_size, seq_length, n_embd).
      - ntok_new (int): New number of tokens/seq_length.

      Returns:
      - torch.Tensor: Resized position embeddings (batch_size, ntok_new, n_embd).
    """
    # Rescale the grid of position embeddings when loading - 24x24+1
    # posemb_cls is for the cls token, posemb_grid for the following tokens
    posemb_cls, posemb_grid = posemb[:, :1], posemb[0, 1:]
    ntok_new -= 1

    gsize_old = int(math.sqrt(len(posemb_grid)))  # 14
    gsize_new = int(math.sqrt(ntok_new))          # 24

    # [1, 196, n_embd] -> [1, 14, 14, n_embd] -> [1, n_embd, 14, 14]
    posemb_grid = posemb_grid.reshape(1, gsize_old, gsize_old, -1).permute(0, 3, 1, 2)
    # [1, n_embd, 14, 14] -> [1, n_embd, 24, 24]
    posemb_grid = F.interpolate(posemb_grid, size=(gsize_new, gsize_new), mode='bicubic')
    # [1, n_embd, 24, 24] -> [1, 24×24, n_embd]
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gsize_new * gsize_new, -1)
    # [1, 24×24+1, n_embd]
    posemb = torch.cat([posemb_cls, posemb_grid], dim=1)
    return posemb
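A minimal shape check for the interpolation, assuming ViT-Base embeddings (n_embd = 768) pre-trained at 224×224 and fine-tuned at 384×384 with 16×16 patches (the tensor is random, only the shapes matter here):

pretrained_posemb = torch.randn(1, 14 * 14 + 1, 768)               # 224x224: 196 patches + [CLS]
resized_posemb = resize_pos_embed(pretrained_posemb, 24 * 24 + 1)  # 384x384: 576 patches + [CLS]
print(resized_posemb.shape)                                        # torch.Size([1, 577, 768])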

Training

Similar to BERT’s [class] token, ViT prepends a learnable embedding to the sequence of embedded patches, whose state at the output of the Transformer encoder serves as the image representation. Both during pre-training and fine-tuning, a classification head is attached to the [class] output (the image representation). The classification head is implemented by an MLP with one hidden layer at pre-training time and by a single linear layer at fine-tuning time.

Position embeddings are added to the patch embeddings to retain positional information. Standard learnable 1D position embeddings are used.
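Putting the pieces together, here is a minimal sketch of how the [class] token, the learnable 1D position embeddings, and a fine-tuning-style classification head could wrap the patch embeddings. It assumes the hypothetical ViTConfig from above (including its num_classes field) and some encoder module that stacks the L Transformer blocks; it is an illustration, not the exact code of the paper or the repository.

import torch
import torch.nn as nn


class ViTForClassification(nn.Module):
    def __init__(self, config, encoder: nn.Module):
        super().__init__()
        self.patch_embedding = PatchEmbedding(config)
        # Learnable [CLS] token, prepended to the patch sequence
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.n_embd))
        # Learnable 1D position embeddings: one per patch plus one for the [CLS] token
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, self.patch_embedding.n_patches + 1, config.n_embd)
        )
        self.encoder = encoder  # stack of L Transformer encoder blocks (assumed given)
        # Fine-tuning-style head: a single linear layer on top of the [CLS] output
        self.head = nn.Linear(config.n_embd, config.num_classes)

    def forward(self, x):  # (batch, in_channels, height, width)
        x = self.patch_embedding(x)                     # (batch, n_patches, n_embd)
        cls = self.cls_token.expand(x.size(0), -1, -1)  # (batch, 1, n_embd)
        x = torch.cat([cls, x], dim=1)                  # (batch, n_patches + 1, n_embd)
        x = x + self.pos_embedding                      # add positional information
        x = self.encoder(x)                             # (batch, n_patches + 1, n_embd)
        return self.head(x[:, 0])                       # classify from the [CLS] token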

The authors note that ViT only performs well when trained on huge datasets with millions of images. Specifically, ResNets perform better with smaller pre-training datasets but plateau sooner than ViT, which performs better with larger pre-training. This result reinforces the intuition that the convolutional inductive bias is useful for smaller datasets, but for larger ones, learning the relevant patterns directly from data is sufficient, even beneficial.

Inductive Bias

ViT has much less image-specific inductive bias than Convolutional Neural Networks (CNNs). In CNNs, locality, two-dimensional neighborhood structure, and translation equivariance are baked into each layer throughout the whole model. In ViT, only MLP layers are local and translationally equivariant, while the self-attention layers are global.

The position embeddings at initialization time carry no information about the 2D positions of the patches and all spatial relations between the patches have to be learned from scratch.

Glossary

  • L: Number of Transformer encoder blocks.
  • H: Size of the embeddings. An embedding is a learnable vector representation of an input token (a word in NLP, an image patch in ViT).
  • A: Number of self-attention heads.
  • num_patches: Input sequence length.

Extra: Code Repository

I have developed a repository with the code of the ViT: the full model and the training process. You can find it at microViT.

Train your own ViT model from scratch and don’t forget to star the repository if you like it!
