Self-attention-based architectures, in particular Transformers, have become the model of choice in natural language processing (NLP). The predominant approach is to pre-train on a large text corpus and then fine-tune on a smaller task-specific dataset.
ViT experiments with applying a standard Transformer directly to images, with the fewest possible modifications. To do so, ViT splits an image into patches and provides the sequence of linear embeddings of these patches as input to a Transformer. Image patches are treated the same way as tokens (words) in NLP applications. The ViT model is trained on image classification in a supervised fashion.
When trained on mid-sized datasets, ViT yields modest accuracies: Transformers lack some of the inductive biases inherent to CNNs, such as translation equivariance and locality, and therefore do not generalize well when trained on insufficient amounts of data. However, the picture changes when ViT is trained on larger datasets: large-scale training trumps inductive bias, and ViT achieves excellent results.
Architecture
ViT’s model architecture is a multi-layer bidirectional Transformer Encoder, following the original Transformer design as closely as possible. If you want more details about the Transformer architecture, you can check out my Transformer blog post.
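For orientation, here is a minimal PyTorch sketch of such an encoder, assuming the pre-norm layout described by the ViT authors: LayerNorm before each sub-layer, a residual connection around it, and a two-layer MLP with a GELU non-linearity. The class names `EncoderBlock` and `Encoder`, the `mlp_ratio=4` default and the use of `nn.MultiheadAttention` are illustrative choices, not code from the author's repository. No attention mask is applied, which is what makes the encoder bidirectional.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Illustrative pre-norm block: x + MSA(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, n_embd: int, n_head: int, mlp_ratio: int = 4, dropout: float = 0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        # batch_first=True -> inputs are (batch, seq_len, n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        # Two-layer MLP with GELU (mlp_ratio is an assumed default)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, mlp_ratio * n_embd),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_ratio * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln1(x)
        # No attn_mask: every token attends to every other token (bidirectional)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out                # residual around self-attention
        x = x + self.mlp(self.ln2(x))   # residual around the MLP
        return x


class Encoder(nn.Module):
    """Stack of identical encoder blocks followed by a final LayerNorm."""

    def __init__(self, n_layer: int, n_embd: int, n_head: int):
        super().__init__()
        self.blocks = nn.ModuleList([EncoderBlock(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return self.ln_f(x)
```

Every block maps `(batch, seq_len, n_embd)` to the same shape, which is why the blocks can simply be stacked.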
ViT denotes the number of Transformer encoder blocks as $L$, the hidden size as $D$, and the number of self-attention heads as $k$. The initial ViT model designs are the following:
Model Name | $L$ (Transformer blocks) | $D$ (Hidden size) | $k$ (Self-Attention heads) |
---|---|---|---|
ViT-Base | 12 | 768 | 12 |
ViT-Large | 24 | 1024 | 16 |
ViT-Huge | 32 | 1280 | 16 |
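The three variants can be captured in a small configuration object. The `ViTConfig` dataclass here is hypothetical: its field names mirror the `config` attributes read by `PatchEmbedding` later in this post, while `n_layer`, `n_head`, the 224-pixel image size and the 16-pixel patch size are assumptions for illustration. The `head_dim` property divides the hidden size by the number of heads, giving the width each attention head works with.

```python
from dataclasses import dataclass


@dataclass
class ViTConfig:
    # Hypothetical config; defaults (224px images, 16px patches, RGB) are assumptions
    img_size: int = 224
    patch_size: int = 16
    in_channels: int = 3
    n_embd: int = 768   # hidden size
    n_layer: int = 12   # number of Transformer encoder blocks
    n_head: int = 12    # number of self-attention heads

    @property
    def head_dim(self) -> int:
        # The hidden size is split evenly across the heads
        assert self.n_embd % self.n_head == 0, "n_embd must be divisible by n_head"
        return self.n_embd // self.n_head


VIT_VARIANTS = {
    "ViT-Base": ViTConfig(n_layer=12, n_embd=768, n_head=12),
    "ViT-Large": ViTConfig(n_layer=24, n_embd=1024, n_head=16),
    "ViT-Huge": ViTConfig(n_layer=32, n_embd=1280, n_head=16),
}
```

With the table's numbers, ViT-Base and ViT-Large use 64-dimensional heads, while ViT-Huge uses 80-dimensional heads.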
The Input: Patch Embedding
The standard Transformer receives as input a 1D sequence of token embeddings. Naive application of self-attention to images would require that each pixel attends to every other pixel. With quadratic cost in the number of pixels, this does not scale to realistic input sizes.
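To put numbers on this, the sketch compares the size of one attention score matrix when every pixel of a 224×224 image is a token versus when the image is cut into 16×16 patches. Both sizes are assumptions chosen to match the usual ViT setting, and the helper `attention_matrix_entries` is hypothetical; it only counts entries and does not run attention.

```python
def attention_matrix_entries(seq_len: int) -> int:
    # Self-attention scores every token against every token: seq_len ** 2 entries per head
    return seq_len ** 2


height = width = 224   # assumed input resolution
patch = 16             # assumed patch size

pixels = height * width                           # 50,176 tokens if each pixel were a token
patches = (height // patch) * (width // patch)    # 196 tokens with 16x16 patches

pixel_cost = attention_matrix_entries(pixels)     # 2,517,630,976 entries
patch_cost = attention_matrix_entries(patches)    # 38,416 entries
```

Patching shrinks the sequence by a factor of 16² = 256, so the quadratic attention cost drops by 256² = 65,536.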
To handle 2D images, we reshape each image of resolution $H \times W$ with $C$ channels into a sequence of flattened 2D patches of size $P \times P$, where the patch size $P$ is a design choice. The resulting number of patches, $N = HW/P^2$, serves as the effective input sequence length for the Transformer. The Transformer uses a constant latent vector size $D$ through all of its layers, so the flattened patches (each of length $P^2 \cdot C$) are mapped to $D$ dimensions with a trainable linear projection. ViT refers to the output of this projection as the patch embeddings.
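The reshape-and-project step can be written directly with tensor operations. This is a minimal sketch, assuming channel-first `(batch, C, H, W)` inputs; the helper name `patchify`, the 224×224 input, the 16×16 patches and the hidden size of 768 are illustrative. Each patch is flattened in `(C, P, P)` order, which is just a fixed permutation of its values and makes no difference to a trainable projection.

```python
import torch
import torch.nn as nn


def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical helper: split (B, C, H, W) images into (B, N, C*P*P) flattened patches."""
    b, c, h, w = images.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image size must be divisible by the patch size"
    # (B, C, H/P, P, W/P, P): carve each spatial axis into (grid index, offset in patch)
    x = images.reshape(b, c, h // p, p, w // p, p)
    # (B, H/P, W/P, C, P, P): grid indices first, so patches come out in row-major order
    x = x.permute(0, 2, 4, 1, 3, 5)
    # (B, N, C*P*P) with N = (H/P) * (W/P)
    return x.reshape(b, (h // p) * (w // p), c * p * p)


D, P = 768, 16                            # illustrative hidden size and patch size
images = torch.randn(2, 3, 224, 224)
patches = patchify(images, P)             # (2, 196, 768): 3 * 16 * 16 = 768 values per patch
projection = nn.Linear(3 * P * P, D)      # the trainable linear projection
patch_embeddings = projection(patches)    # (2, 196, D)
```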
A simple way to perform patch division and projection at once is a 2D convolution. The kernel size is set to the desired patch size and the stride is set to the patch size as well, so that the patches do not overlap. To obtain $D$-dimensional embeddings, we just need to set $D$ (also known as n_embd) as the number of output channels of the convolution.
```python
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        # kernel_size == stride == patch_size: one non-overlapping window per patch
        self.projection = nn.Conv2d(
            in_channels=config.in_channels, out_channels=config.n_embd,
            kernel_size=config.patch_size, stride=config.patch_size
        )
        self.n_patches = (config.img_size // config.patch_size) ** 2

    def forward(self, x):  # (batch, in_channels, height, width)
        x = self.projection(x)  # (batch, n_embd, height // patch_size, width // patch_size)
        x = x.flatten(2)  # (batch, n_embd, n_patches) - "stacking" all patches
        x = x.transpose(1, 2)  # (batch, n_patches, n_embd)
        return x
```
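To check that the convolution really is the same trainable linear projection applied to flattened patches, the sketch copies the weights of an `nn.Conv2d` into an `nn.Linear` and compares both outputs. The `patchify` helper and the sizes (3 channels, 16×16 patches, 768-dimensional embeddings, 224×224 images) are assumptions for illustration; the convolution path repeats the `flatten(2)` and `transpose(1, 2)` steps of `PatchEmbedding.forward`.

```python
import torch
import torch.nn as nn


def patchify(images: torch.Tensor, p: int) -> torch.Tensor:
    # Hypothetical helper: (B, C, H, W) -> (B, N, C*P*P), row-major patch order, (C, P, P) flattening
    b, c, h, w = images.shape
    x = images.reshape(b, c, h // p, p, w // p, p).permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // p) * (w // p), c * p * p)


torch.manual_seed(0)
C, P, D = 3, 16, 768   # illustrative sizes
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
linear = nn.Linear(C * P * P, D)

# conv.weight is (D, C, P, P); flattening it gives the (out_features, in_features) matrix of Linear
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

images = torch.randn(2, C, 224, 224)
via_conv = conv(images).flatten(2).transpose(1, 2)   # (2, 196, 768)
via_linear = linear(patchify(images, P))             # (2, 196, 768)
same = torch.allclose(via_conv, via_linear, atol=1e-4)
```

The two paths agree up to floating-point rounding as long as patches are flattened in the same `(C, P, P)` order as the convolution kernel; the convolution is simply the more convenient way to express the projection.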
Adapt to Higher Resolutions
It is often beneficial to fine-tune at a higher resolution than the one used for pre-training. When feeding images of higher resolution, the authors keep the patch size the same, which results in a larger effective sequence length.
ViT can handle arbitrary sequence lengths (up to memory constraints); however, the pre-trained position embeddings may no longer be meaningful. We therefore perform 2D interpolation of the pre-trained position embeddings, according to their location in the original image. Note that this resolution adjustment and patch extraction are the only points at which an inductive bias about the 2D structure of the images is manually injected into the model.
```python
import math

import torch
import torch.nn.functional as F


# example: 224: (14x14+1) -> 384: (24x24+1) (patch size 16x16)
def resize_pos_embed(posemb, ntok_new):
    """
    Resize the grid of position embeddings to accommodate a new number of tokens.
    Args:
    - posemb (torch.Tensor): Position embeddings (1, seq_length, n_embd), including the [class] token.
    - ntok_new (int): New number of tokens/seq_length, including the [class] token.
    Returns:
    - torch.Tensor: Resized position embeddings (1, ntok_new, n_embd).
    """
    # Rescale the grid of position embeddings when loading - 24x24+1
    # posemb_cls is for the [class] token, posemb_grid for the following (patch) tokens
    posemb_cls, posemb_grid = posemb[:, :1], posemb[0, 1:]
    ntok_new -= 1
    gsize_old = int(math.sqrt(len(posemb_grid)))  # 14
    gsize_new = int(math.sqrt(ntok_new))  # 24
    # [196, n_embd] -> [1, 14, 14, n_embd] -> [1, n_embd, 14, 14]
    posemb_grid = posemb_grid.reshape(1, gsize_old, gsize_old, -1).permute(0, 3, 1, 2)
    # [1, n_embd, 14, 14] -> [1, n_embd, 24, 24]
    posemb_grid = F.interpolate(posemb_grid, size=(gsize_new, gsize_new), mode='bicubic')
    # [1, n_embd, 24, 24] -> [1, 24×24, n_embd]
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gsize_new * gsize_new, -1)
    # [1, 24×24+1, n_embd]
    posemb = torch.cat([posemb_cls, posemb_grid], dim=1)
    return posemb
```
Training
Similar to BERT’s [class] token, ViT prepends a learnable embedding to the sequence of embedded patches, whose state at the output of the Transformer encoder serves as the image representation. Both during pre-training and fine-tuning, a classification head is attached to this image representation (the [class] token output). The classification head is implemented by an MLP with one hidden layer at pre-training time and by a single linear layer at fine-tuning time.
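A sketch of the two heads, assuming the encoder output has shape `(batch, n_patches + 1, n_embd)` with the [class] token at position 0. The function name `make_head`, the hidden width of the pre-training MLP (equal to `n_embd`), its tanh non-linearity and the class counts are assumptions; the text only specifies one hidden layer. For fine-tuning, the ViT authors replace the pre-training head with a zero-initialized linear layer, which is what the second branch builds.

```python
import torch
import torch.nn as nn


def make_head(n_embd: int, n_classes: int, pretraining: bool) -> nn.Module:
    """Hypothetical helper returning the classification head for each phase."""
    if pretraining:
        # MLP with one hidden layer (hidden width and tanh are assumptions)
        return nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.Tanh(),
            nn.Linear(n_embd, n_classes),
        )
    # Fine-tuning: a single zero-initialized linear layer (n_embd -> n_classes)
    head = nn.Linear(n_embd, n_classes)
    nn.init.zeros_(head.weight)
    nn.init.zeros_(head.bias)
    return head


encoder_out = torch.randn(4, 197, 768)   # stand-in for the encoder output
image_repr = encoder_out[:, 0]           # [class] token state: (batch, n_embd)

pretrain_logits = make_head(768, 1000, pretraining=True)(image_repr)   # (4, 1000)
finetune_logits = make_head(768, 10, pretraining=False)(image_repr)    # (4, 10), zeros at init
```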
Position embeddings are added to the patch embeddings to retain positional information. Standard learnable 1D position embeddings are used.
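Putting the [class] token and the position embeddings together, here is a minimal sketch of how the encoder input could be assembled from the patch embeddings. The module name `TokenAssembly` and the initializations (zeros for the [class] token, a truncated normal with standard deviation 0.02 for the position embeddings) are assumptions, not taken from the ViT paper or the author's repository.

```python
import torch
import torch.nn as nn


class TokenAssembly(nn.Module):
    """Hypothetical module: prepend a learnable [class] token, add learnable 1D position embeddings."""

    def __init__(self, n_patches: int, n_embd: int):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, n_embd))
        # One learnable vector per position: [class] + n_patches (init is an assumption)
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, n_embd))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
        # patch_embeddings: (batch, n_patches, n_embd)
        b = patch_embeddings.shape[0]
        cls = self.cls_token.expand(b, -1, -1)           # (batch, 1, n_embd)
        x = torch.cat([cls, patch_embeddings], dim=1)    # (batch, n_patches + 1, n_embd)
        return x + self.pos_embed                        # broadcast over the batch
```

For a 224×224 image with 16×16 patches, this yields a `(1, 197, n_embd)` position embedding table, the `14x14+1` layout that `resize_pos_embed` interpolates.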
The authors note that ViT only performs well when pre-trained on huge datasets with millions of images. Specifically, ResNets perform better with smaller pre-training datasets but plateau sooner than ViT, which performs better with larger pre-training datasets. This result reinforces the intuition that the convolutional inductive bias is useful for smaller datasets, but for larger ones, learning the relevant patterns directly from data is sufficient, even beneficial.
Inductive Bias
ViT has much less image-specific inductive bias than Convolutional Neural Networks (CNNs). In CNNs, locality, two-dimensional neighborhood structure, and translation equivariance are baked into each layer throughout the whole model. In ViT, only MLP layers are local and translationally equivariant, while the self-attention layers are global.
The position embeddings at initialization time carry no information about the 2D positions of the patches and all spatial relations between the patches have to be learned from scratch.
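One way to see why these relations must be learned: without position embeddings, both self-attention and the per-token MLP are permutation-equivariant, so shuffling the patches only shuffles the outputs and the model cannot tell where a patch came from. The sketch checks this numerically with `nn.MultiheadAttention` and a small MLP; the sizes are arbitrary, and the check is an illustration rather than a proof.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd, n_head, n_tokens = 32, 4, 10   # arbitrary illustrative sizes
attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True).eval()
mlp = nn.Sequential(nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Linear(4 * n_embd, n_embd))

x = torch.randn(1, n_tokens, n_embd)   # patch embeddings WITHOUT position embeddings
perm = torch.randperm(n_tokens)        # a random reordering of the patches


def self_attention(t: torch.Tensor) -> torch.Tensor:
    out, _ = attn(t, t, t, need_weights=False)
    return out


with torch.no_grad():
    # Shuffling the input tokens shuffles the outputs in exactly the same way
    attn_equivariant = torch.allclose(self_attention(x[:, perm]), self_attention(x)[:, perm], atol=1e-5)
    mlp_equivariant = torch.allclose(mlp(x[:, perm]), mlp(x)[:, perm], atol=1e-5)
```

Adding position embeddings breaks this symmetry, and it is only through them, learned from data, that ViT can recover the 2D arrangement of the patches.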
Glossary
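- $L$: Number of Transformer encoder blocks.
- $D$ (n_embd): Hidden size, i.e. the size of the embeddings. An embedding is a learnable vector representation of an input token; in ViT, each token is an image patch.
- $k$: Number of self-attention heads.
- n_patches: Number of image patches, i.e. the input sequence length (plus one for the [class] token).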
Extra: Code Repository
I have developed a repository with the code of the full ViT model and its training process. You can find it at microViT.
Train your own ViT model from scratch and don’t forget to star the repository if you like it!