The vanilla ViT is problematic. If you take a look at the original ViT paper [1], you’ll notice that although this deep learning model works extremely well, it requires hundreds of millions of labeled training images to get there. That’s a lot.

This appetite for data is a real problem, so we need a solution. In December 2020, Touvron et al. proposed one in their paper titled “Training data-efficient image transformers & distillation through attention” [2], which makes training a ViT-style model much cheaper in terms of both data and compute. Instead of relying solely on a huge labeled dataset, the authors let the transformer exploit the knowledge of an existing model through distillation. With this approach, they managed to solve ViT’s data-hungry problem while still maintaining high accuracy. What’s even more interesting is that this paper came out only two months after the original ViT!

In this article I am going to discuss the model which the authors refer to as DeiT (Data-efficient image Transformer) as well as how to implement the architecture from scratch. Since DeiT is directly derived from ViT, it is highly recommended to have prior knowledge about ViT before reading this article. You can find my previous article about it in reference [4] at the end of this post.


The Idea of DeiT

DeiT leverages the idea of knowledge distillation. In case you’re not yet familiar with the term, it is essentially a method to transfer the knowledge of a model (teacher) to another one (student) during the training phase. In this case, DeiT acts as the student whereas the teacher is RegNet, a CNN-based model. Later in the inference phase, we will completely omit the RegNet teacher and let the DeiT student make predictions on its own. 

The knowledge distillation technique allows the student model to learn more efficiently, which makes sense since it not only learns the patterns in the dataset from scratch but also benefits from the knowledge of the teacher during training. Think of it like someone learning a new subject. They could study purely from books, but it would be much more efficient if they also had a mentor to provide guidance. In this analogy, the learner acts as the student, the books are the dataset, and the mentor is the teacher. With this mechanism, the student derives knowledge from both the dataset and the teacher simultaneously. As a result, training the student model requires far less data. To illustrate this, the original ViT needed 300 million images for training (JFT-300M dataset), while DeiT relies only on the roughly 1.3 million training images of ImageNet-1K. That’s more than 200x smaller!

Technically speaking, knowledge distillation can be done without making any modifications to the student or teacher models. Rather, the changes are only made to the loss function and the training procedure. However, the authors found that they could achieve more by slightly modifying the network structure, which at the same time also changes the distillation mechanism. Specifically, instead of sticking with the original ViT and applying a standard distillation process to it, they modified the architecture and named the result DeiT. To be exact, ViT only has the so-called class token, whereas DeiT uses the class token plus an additional one called the distillation token. Look at Figure 1 below to see where these two tokens are placed in the network.

Figure 1. The DeiT architecture [2].

DeiT and ViT Variants

There are three DeiT variants proposed in the paper, namely DeiT-Ti (Tiny), DeiT-S (Small) and DeiT-B (Base). Notice in Figure 2 that the largest DeiT variant (DeiT-B) is equivalent to the smallest ViT variant (ViT-B) in terms of the model size. So, this implicitly means that DeiT was indeed designed to challenge ViT by prioritizing efficiency.

Figure 2. DeiT and ViT variants [1, 2, 3].

Later in the coding part, I am going to implement the DeiT-B architecture. I will make the code as versatile as possible so that you can easily adjust the parameters if you want to implement the other variants instead. Taking a closer look at the DeiT-B row in the above table, we are going to configure the model such that it maps each image patch to a one-dimensional tensor of size 768. The elements in this tensor will then be split across 12 heads inside the attention layer, so each attention head is responsible for processing 64 features. Remember that the attention layer we are talking about is a component of a Transformer encoder layer. In the case of DeiT-B, this layer is repeated 12 times before the tensor is eventually forwarded to the output layer. If we implement it correctly according to this configuration, the model should contain roughly 86 million trainable parameters.
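The relationship between these numbers can be checked with a few lines of Python. The following is only a minimal sketch: the DEIT_VARIANTS dictionary and the head_dim() helper are hypothetical names that are not part of the article's code, and the values for DeiT-Ti and DeiT-S are the ones listed in the variant table of the paper [2].

```python
# Hypothetical lookup table of the three DeiT variants (values as listed in the paper).
DEIT_VARIANTS = {
    "DeiT-Ti": {"embed_dim": 192, "num_heads": 3,  "num_layers": 12},
    "DeiT-S":  {"embed_dim": 384, "num_heads": 6,  "num_layers": 12},
    "DeiT-B":  {"embed_dim": 768, "num_heads": 12, "num_layers": 12},
}

def head_dim(variant):
    """Number of features handled by each attention head."""
    cfg = DEIT_VARIANTS[variant]
    # The embedding dimension must split evenly across the heads.
    assert cfg["embed_dim"] % cfg["num_heads"] == 0
    return cfg["embed_dim"] // cfg["num_heads"]

for name in DEIT_VARIANTS:
    print(f"{name}\t: {head_dim(name)} features per head")
```

All three variants end up with 64 features per head, so switching between them only changes the number of heads and, with it, the embedding dimension.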

Experimental Results

There are plenty of experiments reported in the DeiT paper. Below is one of them that grabbed my attention the most.

Figure 3. The performance of different models on ImageNet-1K dataset without additional training data [2].

The above figure was obtained by training multiple models on the ImageNet-1K dataset, including EfficientNet, ViT, and DeiT itself. In fact, there are two DeiT versions displayed in the figure: DeiT and DeiT⚗. The strange symbol in the latter (called “alembic”) marks the DeiT model trained using the proposed distillation mechanism.

The figure shows that ViT falls far behind the plain DeiT (the variant without the alembic symbol, i.e., without the proposed distillation mechanism) while running at a similar speed. The accuracy improves even further when the novel distillation mechanism is applied and the model is fine-tuned on the same images upscaled to 384×384, hence the name DeiT-B⚗↑384. In theory ViT could have done better, but in this experiment it could not reach its full potential because it was not trained on the enormous JFT-300M dataset. This is just one result that demonstrates the advantage of DeiT over ViT in a data-limited setting.

That should be everything you need to know before implementing the DeiT architecture from scratch. Don’t worry if you haven’t fully grasped the entire idea of this model yet, since we will get into the details in a minute.


DeiT Implementation

As I mentioned earlier, the model we are about to implement is the DeiT-B variant. But since I also want to show you the novel knowledge distillation mechanism, I’ll specifically focus on the one referred to as DeiT-B⚗↑384. Now let’s start by importing the required modules.

# Codeblock 1
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from torchinfo import summary

Now that the modules have been imported, the next step is to initialize some configurable parameters in Codeblock 2 below, which are all set according to the DeiT-B specifications. At line #(1), the IMAGE_SIZE variable is set to 384 since we are about to simulate the DeiT version that accepts the upscaled images. Despite this higher resolution input, we still keep the patch size the same as when working with 224×224 images, i.e., 16×16, as written at line #(2). Next, we set EMBED_DIM to 768 (#(3)), while the NUM_HEADS and NUM_LAYERS variables are both set to 12 (#(4–5)). The authors decided to use the same FFN structure as the one used in ViT, in which the size of the hidden layer is four times larger than the embedding dimension (#(6)). The number of patches itself can be calculated using the simple formula shown at line #(7). In this case, since our image size is 384 and the patch size is 16, the value of NUM_PATCHES is going to be 576. Lastly, I set NUM_CLASSES to 1000, simulating a classification task on the ImageNet-1K dataset (#(8)).

# Codeblock 2
BATCH_SIZE   = 1
IMAGE_SIZE   = 384     #(1)
IN_CHANNELS  = 3

PATCH_SIZE   = 16      #(2)
EMBED_DIM    = 768     #(3)
NUM_HEADS    = 12      #(4)
NUM_LAYERS   = 12      #(5)
FFN_SIZE     = EMBED_DIM * 4    #(6)

NUM_PATCHES  = (IMAGE_SIZE//PATCH_SIZE) ** 2    #(7)

NUM_CLASSES  = 1000    #(8)
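To make the effect of the input resolution concrete, the sketch below applies the formula from line #(7) to both resolutions mentioned above. The helper name num_patches() is an assumption made for illustration only; it is not used anywhere else in this article.

```python
def num_patches(image_size, patch_size=16):
    """Number of non-overlapping square patches in a square image."""
    # The image side must be divisible by the patch side.
    assert image_size % patch_size == 0
    return (image_size // patch_size) ** 2

for size in (224, 384):
    n = num_patches(size)
    # DeiT adds two extra tokens (class and distillation) to the sequence.
    print(f"{size}x{size}\t: {n} patches, {n + 2} tokens in total")
```

Moving from 224×224 to 384×384 raises the patch count from 196 to 576, so the sequence processed by the attention layers becomes roughly three times longer.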

Treating an Image as a Sequence of Patches

When it comes to processing images using transformers, what we need to do is to treat them as a sequence of patches. Such a patching mechanism is implemented in the Patcher class below.

# Codeblock 3
class Patcher(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=IN_CHANNELS,    #(1)
                              out_channels=EMBED_DIM, 
                              kernel_size=PATCH_SIZE,     #(2)
                              stride=PATCH_SIZE)          #(3)

        self.flatten = nn.Flatten(start_dim=2)            #(4)

    def forward(self, x):
        print(f'original\t: {x.size()}')

        x = self.conv(x)        #(5)
        print(f'after conv\t: {x.size()}')

        x = self.flatten(x)     #(6)
        print(f'after flatten\t: {x.size()}')

        x = x.permute(0, 2, 1)  #(7)
        print(f'after permute\t: {x.size()}')

        return x

You can see in Codeblock 3 that we use an nn.Conv2d layer to do so (#(1)). Keep in mind that the operation done by this layer is not intended to actually perform convolution like in CNN-based models. Instead, we use it as a trick to extract the information of each patch in a non-overlapping manner, which is the reason that we set both kernel_size (#(2)) and stride (#(3)) to PATCH_SIZE (16). The operation done by this convolution layer involves the patching mechanism only — we haven’t actually put these patches into sequence just yet. In order to do so, we can simply utilize an nn.Flatten layer which I initialize at line #(4) in the above codeblock. What we need to do inside the forward() method is to pass the input tensor through the conv (#(5)) and flatten (#(6)) layers. It is also necessary to perform the permute operation afterwards because we want the patch sequence to be placed along axis 1 and the embedding dimension along axis 2 (#(7)).

Now let’s test the Patcher class above using the following codeblock. Here I test it with a dummy tensor whose dimension is set to 1×3×384×384, simulating a single RGB image of size 384×384.

# Codeblock 4
patcher = Patcher()
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

x = patcher(x)

And below is what the output looks like. Here I print out the tensor dimension after each step so that you can clearly see the flow inside the network. 

# Codeblock 4 Output
original      : torch.Size([1, 3, 384, 384])
after conv    : torch.Size([1, 768, 24, 24])  #(1)
after flatten : torch.Size([1, 768, 576])     #(2)
after permute : torch.Size([1, 576, 768])     #(3)

Notice at line #(1) that the spatial dimension of the tensor changed from 384×384 to 24×24, which indicates that our convolution layer has successfully performed the patching process. Every single pixel in the 24×24 map now represents one 16×16 patch of the input image. Furthermore, notice in the same line that the number of channels increased from 3 to EMBED_DIM (768). Later on, we will treat this as the number of features that store the information of a single patch. Next, we can see at line #(2) that our flatten layer flattened the 24×24 grid into a one-dimensional sequence of length 576, which means that our image is now represented as a sequence of patch tokens. The permute operation at line #(3) then moves the sequence to axis 1 and the features to axis 2, i.e., the (batch, sequence, feature) layout that nn.MultiheadAttention expects when batch_first=True.
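To back up the claim that the convolution here is only a patching trick, the following sketch compares it with an explicit "cut into patches, then apply one linear projection" route. It uses small hypothetical sizes (a 32×32 image with 8×8 patches) so that it runs quickly, and F.unfold is my own choice for extracting the patches rather than something taken from the DeiT code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Small hypothetical sizes so the check runs quickly.
image_size, patch_size, in_channels, embed_dim = 32, 8, 3, 16

conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, in_channels, image_size, image_size)

# Route 1: strided convolution, then flatten and permute (as in Patcher).
tokens_conv = conv(x).flatten(start_dim=2).permute(0, 2, 1)

# Route 2: cut the image into flattened patches, then apply one linear
# projection that reuses the convolution's weights and bias.
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (2, 192, 16)
patches = patches.permute(0, 2, 1)                                 # (2, 16, 192)
weight = conv.weight.reshape(embed_dim, -1)                        # (16, 192)
tokens_linear = patches @ weight.T + conv.bias

print(f'tokens\t\t: {tokens_conv.size()}')
print(f'same result\t: {torch.allclose(tokens_conv, tokens_linear, atol=1e-4)}')
```

Both routes produce one embedding vector per patch, which is why the Conv2d layer can be read as a shared linear projection applied to every 16×16 patch.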

Transformer Encoder

Now let’s put our Patcher class aside for a while, since in this section we are going to implement the transformer encoder layer. This layer is taken directly from the original ViT paper, and its architecture can be seen in Figure 4 below. Take a look at Codeblock 5 to see how I implement it.

Figure 4. The Transformer encoder layer used in ViT [1].

# Codeblock 5
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.norm_0 = nn.LayerNorm(EMBED_DIM)    #(1)

        self.multihead_attention = nn.MultiheadAttention(EMBED_DIM,    #(2)
                                                         num_heads=NUM_HEADS, 
                                                         batch_first=True)

        self.norm_1 = nn.LayerNorm(EMBED_DIM)    #(3)

        self.ffn = nn.Sequential(                #(4)
            nn.Linear(in_features=EMBED_DIM, out_features=FFN_SIZE),
            nn.GELU(), 
            nn.Linear(in_features=FFN_SIZE, out_features=EMBED_DIM),
        )

    def forward(self, x):

        residual = x
        print(f'residual dim\t: {residual.size()}')

        x = self.norm_0(x)
        print(f'after norm\t: {x.size()}')

        x = self.multihead_attention(x, x, x)[0]
        print(f'after attention\t: {x.size()}')

        x = x + residual
        print(f'after addition\t: {x.size()}')

        residual = x
        print(f'residual dim\t: {residual.size()}')

        x = self.norm_1(x)
        print(f'after norm\t: {x.size()}')

        x = self.ffn(x)
        print(f'after ffn\t: {x.size()}')

        x = x + residual
        print(f'after addition\t: {x.size()}')

        return x

According to the above figure, there are four layers that need to be initialized in the __init__() method, namely a multihead attention layer (#(2)), an MLP layer (#(4)), which is equivalent to the FFN in Figure 1, and two layer normalization layers (#(1,3)). I am not going to go deeper into the above code since it is exactly the same as what I explained in my previous article about ViT [4], so I do recommend checking that article to better understand how the Encoder class works. Additionally, if you need an in-depth explanation specifically of the attention mechanism, you can also read my previous transformer article [5], where I implemented the entire transformer architecture from scratch.

We can now just go ahead to the testing code to see how the tensor flows through the network. In the following codeblock, I assume that the input tensor x is an image that has already been processed by the Patcher block we created earlier, which is the reason why I set it to have the size of 1×576×768.

# Codeblock 6
encoder = Encoder()
x = torch.randn(BATCH_SIZE, NUM_PATCHES, EMBED_DIM)

x = encoder(x)

# Codeblock 6 Output
residual dim    : torch.Size([1, 576, 768])
after norm      : torch.Size([1, 576, 768])
after attention : torch.Size([1, 576, 768])
after addition  : torch.Size([1, 576, 768])
residual dim    : torch.Size([1, 576, 768])
after norm      : torch.Size([1, 576, 768])
after ffn       : torch.Size([1, 576, 768])
after addition  : torch.Size([1, 576, 768])

According to the above result, we can see that the final output tensor dimension is exactly the same as that of the input. This property allows us to stack multiple encoder blocks without disrupting the entire network structure. Furthermore, although the shape of the tensor appears to be constant along its way to the last layer, there are actually lots of dimensionality changes happening especially inside the attention and the FFN layers. However, these changes are not printed since the processes are done internally by nn.MultiheadAttention and nn.Sequential, respectively.
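The hidden dimensionality changes mentioned above can be made visible with a small sketch. It assumes the DeiT-B settings from Codeblock 2 but uses a short hypothetical sequence of 10 tokens to keep it light, and it pulls the two linear layers of the FFN apart (leaving out the GELU in between) purely so that the intermediate shape can be printed.

```python
import torch
import torch.nn as nn

EMBED_DIM, NUM_HEADS, FFN_SIZE = 768, 12, 768 * 4

attention = nn.MultiheadAttention(EMBED_DIM, num_heads=NUM_HEADS, batch_first=True)
expand = nn.Linear(EMBED_DIM, FFN_SIZE)    # first FFN layer
shrink = nn.Linear(FFN_SIZE, EMBED_DIM)    # second FFN layer

x = torch.randn(1, 10, EMBED_DIM)          # (batch, sequence, features)

# Each head works on EMBED_DIM // NUM_HEADS features.
print(f'features per head\t: {attention.head_dim}')

# The attention weights have one row and one column per token.
out, weights = attention(x, x, x)
print(f'attention weights\t: {weights.size()}')

# The FFN widens every token to FFN_SIZE before projecting it back.
hidden = expand(x)
print(f'ffn hidden\t\t: {hidden.size()}')
print(f'ffn output\t\t: {shrink(hidden).size()}')
```

The tensor only returns to the 768-feature width at the end of each sub-block, which is what makes the residual additions in Codeblock 5 possible.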

The Entire DeiT Architecture

All the code I explained in the previous sections is identical to what we would use for constructing the ViT architecture. In this section, you will finally find the parts that clearly differentiate DeiT from ViT. Let’s now focus on the layers we need to initialize in the __init__() method of the DeiT class below.

# Codeblock 7a
class DeiT(nn.Module):
    def __init__(self):
        super().__init__()

        self.patcher = Patcher()    #(1)
        
        # Leading dimension is 1 so that the parameter shapes do not depend on BATCH_SIZE.
        self.class_token = nn.Parameter(torch.zeros(1, 1, EMBED_DIM))  #(2)
        self.dist_token  = nn.Parameter(torch.zeros(1, 1, EMBED_DIM))  #(3)

        trunc_normal_(self.class_token, std=.02)    #(4)
        trunc_normal_(self.dist_token, std=.02)     #(5)

        self.pos_embedding = nn.Parameter(torch.zeros(1, NUM_PATCHES+2, EMBED_DIM))  #(6)
        trunc_normal_(self.pos_embedding, std=.02)  #(7)
        
        self.encoders = nn.ModuleList([Encoder() for _ in range(NUM_LAYERS)])  #(8)
        
        self.norm_out = nn.LayerNorm(EMBED_DIM)     #(9)

        self.class_head = nn.Linear(in_features=EMBED_DIM, out_features=NUM_CLASSES)  #(10)
        self.dist_head  = nn.Linear(in_features=EMBED_DIM, out_features=NUM_CLASSES)  #(11)

The first component I initialize here is the Patcher we created earlier (#(1)). Next, instead of only using a class token, DeiT utilizes another one named the distillation token. These two tokens, which in the above code are referred to as class_token (#(2)) and dist_token (#(3)), will later be prepended to the patch token sequence. We set these two additional tokens to be trainable, allowing them to interact with and learn from the patch tokens later during the processing in the attention layer. Notice that I initialize these two trainable tensors using trunc_normal_() with a standard deviation of 0.02 (#(4–5)). In case you’re not yet familiar with the function, it draws values from a normal distribution and redraws any value that falls outside the interval [a, b], which defaults to [-2, 2]. Keep in mind that these bounds are absolute values rather than multiples of the standard deviation, so with std=0.02 the truncation almost never triggers and the result is practically a normal distribution with a small spread. The small standard deviation is what matters here: torch.randn() draws from a distribution with a standard deviation of 1, which would produce much larger initial values.
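To get a feeling for what this initialization produces, here is a minimal sketch that compares it with torch.randn(). It uses torch.nn.init.trunc_normal_ from PyTorch itself as a stand-in for the timm function imported in Codeblock 1, on the assumption that both behave alike for this purpose.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
EMBED_DIM = 768

# Token initialized like class_token / dist_token in Codeblock 7a.
token = torch.zeros(1, 1, EMBED_DIM)
nn.init.trunc_normal_(token, std=.02)

# Same shape, but drawn from a standard normal distribution instead.
reference = torch.randn(1, 1, EMBED_DIM)

print(f'trunc_normal_\t: std={token.std().item():.4f}, max abs={token.abs().max().item():.4f}')
print(f'torch.randn\t: std={reference.std().item():.4f}, max abs={reference.abs().max().item():.4f}')
```

The initialized token stays very close to zero, whereas the torch.randn() tensor spreads over a far wider range.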

Afterwards, we create a learnable positional embedding tensor using the same technique at lines #(6) and #(7). Keep in mind that this tensor will be summed element-wise with the sequence of patch tokens after the class and distillation tokens have been prepended to it. For this reason, we need to set the length of axis 1 of this embedding tensor to NUM_PATCHES+2. Meanwhile, the transformer encoder layer is initialized inside an nn.ModuleList, which allows us to repeat the layer NUM_LAYERS (12) times (#(8)). The output produced by the last encoder layer in the stack will be processed with a layer norm (#(9)) before eventually being forwarded to the classification (#(10)) and distillation heads (#(11)).

Now let’s move on to the forward() method which you can see in the Codeblock 7b below.

# Codeblock 7b
    def forward(self, x):
        print(f'original\t\t: {x.size()}')
        
        x = self.patcher(x)           #(1)
        print(f'after patcher\t\t: {x.size()}')
        
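        # Repeat the tokens along the batch axis so that any batch size works.
        class_token = self.class_token.expand(x.size(0), -1, -1)
        dist_token  = self.dist_token.expand(x.size(0), -1, -1)
        x = torch.cat([class_token, dist_token, x], dim=1)  #(2)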
        print(f'after concat\t\t: {x.size()}')
        
        x = x + self.pos_embedding    #(3)
        print(f'after pos embed\t\t: {x.size()}')
        
        for i, encoder in enumerate(self.encoders):
            x = encoder(x)            #(4)
            print(f"after encoder #{i}\t: {x.size()}")

        x = self.norm_out(x)          #(5)
        print(f'after norm\t\t: {x.size()}')
        
        class_out = x[:, 0]           #(6)
        print(f'class_out\t\t: {class_out.size()}')
        
        dist_out  = x[:, 1]           #(7)
        print(f'dist_out\t\t: {dist_out.size()}')
        
        class_out = self.class_head(class_out)    #(8)
        print(f'after class_head\t: {class_out.size()}')
        
        dist_out  = self.dist_head(dist_out)       #(9)
        print(f'after dist_head\t\t: {dist_out.size()}')
        
        return class_out, dist_out

After taking a raw image as the input, the forward() method processes it using the patcher layer (#(1)). As we have previously discussed, this layer is responsible for converting the image into a sequence of patches. Subsequently, we concatenate the class and distillation tokens to it using torch.cat() (#(2)). It is worth noting that although the illustration in Figure 1 places the class token at the beginning of the sequence and the distillation token at the end, the code in the official GitHub repository [6] places the distillation token right after the class token. Thus, I decided to follow this approach in our implementation. Figure 5 below illustrates what the resulting tensor looks like.

Figure 5. How class and distillation tokens are concatenated to the patch token sequence in our implementation [3].

Still with Codeblock 7b, what we need to do next is to add the positional embedding tensor to the token sequence, which is done at line #(3). We then pass the tensor through the stack of encoders using a simple loop (#(4)) and normalize the output produced by the last encoder layer (#(5)). At lines #(6) and #(7) we extract the class and distillation tokens we prepended earlier using standard tensor slicing. These two tokens should now contain meaningful information for the classification task since they have already learned the context of the image through the self-attention layers. The resulting class_out and dist_out tensors are then forwarded to two output layers with identical structure and processed independently (#(8–9)). Since this model is intended for classification, these two output layers produce tensors containing logits, in which every single element represents the raw prediction score of a class.
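The token bookkeeping at lines #(2), #(6) and #(7) can be illustrated in isolation. The sketch below uses small hypothetical sizes and fills the tensors with recognizable constants instead of learned values, so it is only meant to show which positions end up holding which token.

```python
import torch

batch_size, num_patches, embed_dim = 4, 6, 8   # hypothetical toy sizes

# Constant-valued stand-ins so each token type is easy to recognize.
class_token = torch.full((1, 1, embed_dim), 1.0)
dist_token  = torch.full((1, 1, embed_dim), 2.0)
patches     = torch.zeros(batch_size, num_patches, embed_dim)

# Repeat the single class/distillation token for every image in the batch.
class_tokens = class_token.expand(batch_size, -1, -1)
dist_tokens  = dist_token.expand(batch_size, -1, -1)

# Order used in this article: class token, distillation token, then patches.
sequence = torch.cat([class_tokens, dist_tokens, patches], dim=1)
print(f'sequence\t: {sequence.size()}')

# Index 0 recovers the class token, index 1 the distillation token.
class_out = sequence[:, 0]
dist_out  = sequence[:, 1]
print(f'class_out\t: {class_out.size()}')
print(f'dist_out\t: {dist_out.size()}')
```

Because both tokens sit at fixed positions at the front of the sequence, plain indexing is enough to pull them out again after the encoder stack.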

We can see the flow of the DeiT model with the following testing code, where we start with the raw input image (#(1)), turn it into a sequence of patches (#(2)), concatenate the class and distillation tokens (#(3)), and so on until we eventually get the output from both the classification and distillation heads (#(4–5)).

# Codeblock 8
deit = DeiT()
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

class_out, dist_out = deit(x)

# Codeblock 8 Output
original          : torch.Size([1, 3, 384, 384])  #(1)
after patcher     : torch.Size([1, 576, 768])     #(2)
after concat      : torch.Size([1, 578, 768])     #(3)
after pos embed   : torch.Size([1, 578, 768])
after encoder #0  : torch.Size([1, 578, 768])
after encoder #1  : torch.Size([1, 578, 768])
after encoder #2  : torch.Size([1, 578, 768])
after encoder #3  : torch.Size([1, 578, 768])
after encoder #4  : torch.Size([1, 578, 768])
after encoder #5  : torch.Size([1, 578, 768])
after encoder #6  : torch.Size([1, 578, 768])
after encoder #7  : torch.Size([1, 578, 768])
after encoder #8  : torch.Size([1, 578, 768])
after encoder #9  : torch.Size([1, 578, 768])
after encoder #10 : torch.Size([1, 578, 768])
after encoder #11 : torch.Size([1, 578, 768])
after norm        : torch.Size([1, 578, 768])
class_out         : torch.Size([1, 768])
dist_out          : torch.Size([1, 768])
after class_head  : torch.Size([1, 1000])         #(4)
after dist_head   : torch.Size([1, 1000])         #(5)

You can also run the following code if you want to see even more details of the architecture. The resulting output shows that this network contains about 87.6 million parameters, which is slightly higher than the 86 million reported in the paper for DeiT-B. The gap is consistent with the components that our DeiT-B⚗↑384 setup adds on top of a plain 224×224 model with a single head: a second output head (769,000 parameters), the distillation token, and a longer positional embedding for the 578-token sequence. Still, the code I wrote above is much simpler than the one in the official repository, so please let me know if you spot any mistakes!

# Codeblock 9
summary(deit, input_size=(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE))

# Codeblock 9 Output
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
DeiT                                     [1, 1000]                 445,440
├─Patcher: 1-1                           [1, 576, 768]             --
│    └─Conv2d: 2-1                       [1, 768, 24, 24]          590,592
│    └─Flatten: 2-2                      [1, 768, 576]             --
├─ModuleList: 1-2                        --                        --
│    └─Encoder: 2-3                      [1, 578, 768]             --
│    │    └─LayerNorm: 3-1               [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-2      [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-3               [1, 578, 768]             1,536
│    │    └─Sequential: 3-4              [1, 578, 768]             4,722,432
│    └─Encoder: 2-4                      [1, 578, 768]             --
│    │    └─LayerNorm: 3-5               [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-6      [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-7               [1, 578, 768]             1,536
│    │    └─Sequential: 3-8              [1, 578, 768]             4,722,432
│    └─Encoder: 2-5                      [1, 578, 768]             --
│    │    └─LayerNorm: 3-9               [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-10     [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-11              [1, 578, 768]             1,536
│    │    └─Sequential: 3-12             [1, 578, 768]             4,722,432
│    └─Encoder: 2-6                      [1, 578, 768]             --
│    │    └─LayerNorm: 3-13              [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-14     [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-15              [1, 578, 768]             1,536
│    │    └─Sequential: 3-16             [1, 578, 768]             4,722,432
│    └─Encoder: 2-7                      [1, 578, 768]             --
│    │    └─LayerNorm: 3-17              [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-18     [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-19              [1, 578, 768]             1,536
│    │    └─Sequential: 3-20             [1, 578, 768]             4,722,432
│    └─Encoder: 2-8                      [1, 578, 768]             --
│    │    └─LayerNorm: 3-21              [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-22     [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-23              [1, 578, 768]             1,536
│    │    └─Sequential: 3-24             [1, 578, 768]             4,722,432
│    └─Encoder: 2-9                      [1, 578, 768]             --
│    │    └─LayerNorm: 3-25              [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-26     [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-27              [1, 578, 768]             1,536
│    │    └─Sequential: 3-28             [1, 578, 768]             4,722,432
│    └─Encoder: 2-10                     [1, 578, 768]             --
│    │    └─LayerNorm: 3-29              [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-30     [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-31              [1, 578, 768]             1,536
│    │    └─Sequential: 3-32             [1, 578, 768]             4,722,432
│    └─Encoder: 2-11                     [1, 578, 768]             --
│    │    └─LayerNorm: 3-33              [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-34     [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-35              [1, 578, 768]             1,536
│    │    └─Sequential: 3-36             [1, 578, 768]             4,722,432
│    └─Encoder: 2-12                     [1, 578, 768]             --
│    │    └─LayerNorm: 3-37              [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-38     [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-39              [1, 578, 768]             1,536
│    │    └─Sequential: 3-40             [1, 578, 768]             4,722,432
│    └─Encoder: 2-13                     [1, 578, 768]             --
│    │    └─LayerNorm: 3-41              [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-42     [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-43              [1, 578, 768]             1,536
│    │    └─Sequential: 3-44             [1, 578, 768]             4,722,432
│    └─Encoder: 2-14                     [1, 578, 768]             --
│    │    └─LayerNorm: 3-45              [1, 578, 768]             1,536
│    │    └─MultiheadAttention: 3-46     [1, 578, 768]             2,362,368
│    │    └─LayerNorm: 3-47              [1, 578, 768]             1,536
│    │    └─Sequential: 3-48             [1, 578, 768]             4,722,432
├─LayerNorm: 1-3                         [1, 578, 768]             1,536
├─Linear: 1-4                            [1, 1000]                 769,000
├─Linear: 1-5                            [1, 1000]                 769,000
==========================================================================================
Total params: 87,630,032
Trainable params: 87,630,032
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 398.43
==========================================================================================
Input size (MB): 1.77
Forward/backward pass size (MB): 305.41
Params size (MB): 235.34
Estimated Total Size (MB): 542.52
==========================================================================================
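The total reported by torchinfo can also be reproduced by hand, which helps when comparing it with the 86 million parameters quoted in the paper. The following sketch is plain arithmetic based on the layer shapes used in Codeblocks 2 to 7a. The function name count_params() and its arguments are hypothetical, and the comparison against a 224×224 model with one token and one head is my own back-of-the-envelope check rather than something taken from the paper.

```python
def count_params(image_size, num_tokens_extra, num_heads_out,
                 patch_size=16, in_channels=3, embed_dim=768,
                 num_layers=12, num_classes=1000):
    """Count trainable parameters of the ViT/DeiT-style model built in this article."""
    ffn_size = embed_dim * 4
    num_patches = (image_size // patch_size) ** 2

    patcher = in_channels * patch_size**2 * embed_dim + embed_dim    # Conv2d weight + bias
    tokens = num_tokens_extra * embed_dim                            # class (+ distillation) token
    pos_embedding = (num_patches + num_tokens_extra) * embed_dim

    layer_norm = 2 * embed_dim                                       # scale + shift
    attention = 4 * (embed_dim * embed_dim + embed_dim)              # Q, K, V and output projections
    ffn = (embed_dim * ffn_size + ffn_size) + (ffn_size * embed_dim + embed_dim)
    encoder = 2 * layer_norm + attention + ffn

    heads = num_heads_out * (embed_dim * num_classes + num_classes)  # output Linear layers

    return patcher + tokens + pos_embedding + num_layers * encoder + layer_norm + heads

distilled_384 = count_params(image_size=384, num_tokens_extra=2, num_heads_out=2)
plain_224     = count_params(image_size=224, num_tokens_extra=1, num_heads_out=1)

print(f'two tokens, two heads, 384x384\t: {distilled_384:,}')  # 87,630,032
print(f'one token, one head, 224x224\t: {plain_224:,}')        # 86,567,656
```

The first number matches the torchinfo total exactly, while the second one rounds to the 86 million quoted for DeiT-B.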

How Classification and Distillation Heads Work

I would like to talk a little bit about the tensors produced by the two output heads. During the training phase, the output from the classification head is compared with the original ground truth (one-hot label), and the classification performance is evaluated using cross-entropy loss. Meanwhile, the output from the distillation head is compared with the output produced by the teacher model, i.e., RegNet. We always treat the output of the teacher as the truth regardless of whether its prediction is correct. And that’s essentially how knowledge is distilled from RegNet to DeiT.

There are two ways to perform knowledge distillation: soft distillation and hard distillation. The former uses the full output distribution of the teacher (its logits passed through softmax, typically softened with a temperature) as the target, rather than only the argmaxed class. This kind of additional ground truth is referred to as a soft label. If we decide to use this technique, we use the Kullback-Leibler (KL) divergence loss, which is suitable for comparing two probability distributions: one derived from the distillation head and the other from the teacher output. On the other hand, hard distillation is a technique where the prediction made by the teacher is argmaxed prior to being compared with the output from the distillation head. In this case, the teacher output is referred to as a hard label, which plays the same role as a typical ground-truth class label. For this reason, with hard labels we can simply use the standard cross-entropy loss for this head. Although the authors found that hard distillation performed better than soft distillation, I still think it is worth experimenting with both approaches if you plan to use DeiT in your upcoming project to see whether this finding also applies to your case.
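The two options described above can be sketched as loss functions. This is a minimal illustration rather than the official training code: the function names are hypothetical, the equal 0.5/0.5 weighting and the (1 - lam) / lam * tau² weighting follow my reading of the loss formulation in the paper [2], and the default values of tau and lam are placeholders that you would need to tune. Random tensors stand in for the DeiT heads and the RegNet teacher.

```python
import torch
import torch.nn.functional as F

def hard_distillation_loss(class_out, dist_out, teacher_logits, labels):
    """Cross-entropy on both heads; the teacher's argmax serves as the second label."""
    teacher_labels = teacher_logits.argmax(dim=1)
    loss_class = F.cross_entropy(class_out, labels)
    loss_dist = F.cross_entropy(dist_out, teacher_labels)
    return 0.5 * loss_class + 0.5 * loss_dist

def soft_distillation_loss(class_out, dist_out, teacher_logits, labels,
                           tau=3.0, lam=0.1):
    """Cross-entropy on the class head plus temperature-scaled KL on the distillation head."""
    loss_class = F.cross_entropy(class_out, labels)
    # F.kl_div expects log-probabilities as input and probabilities as target.
    student_log_probs = F.log_softmax(dist_out / tau, dim=1)
    teacher_probs = F.softmax(teacher_logits / tau, dim=1)
    loss_dist = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return (1 - lam) * loss_class + lam * tau**2 * loss_dist

torch.manual_seed(0)
batch_size, num_classes = 4, 1000

# Random stand-ins for the two DeiT heads, the teacher output and the labels.
class_out      = torch.randn(batch_size, num_classes)
dist_out       = torch.randn(batch_size, num_classes)
teacher_logits = torch.randn(batch_size, num_classes)
labels         = torch.randint(0, num_classes, (batch_size,))

hard_loss = hard_distillation_loss(class_out, dist_out, teacher_logits, labels)
soft_loss = soft_distillation_loss(class_out, dist_out, teacher_logits, labels)
print(f'hard distillation loss\t: {hard_loss.item():.4f}')
print(f'soft distillation loss\t: {soft_loss.item():.4f}')
```

In a real training loop the teacher would be run without gradient tracking, and class_out and dist_out would come from the DeiT model defined in Codeblocks 7a and 7b.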

During the inference phase, we no longer use the teacher model. Think of it like the student has graduated and is ready to work on its own. Despite the absence of the teacher, the output from the distillation head is still utilized. According to the official GitHub repository [6], the logits produced by the classification and distillation heads are simply averaged before being argmaxed to obtain the final prediction.
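The inference step described above boils down to two lines. The sketch below uses random stand-ins for the logits of the two heads, so the predicted classes are meaningless; it is only meant to show the averaging and the argmax.

```python
import torch

torch.manual_seed(0)

# Random stand-ins for the logits returned by the two DeiT heads.
class_out = torch.randn(2, 1000)
dist_out  = torch.randn(2, 1000)

# Average the two heads, then pick the highest-scoring class per image.
logits = (class_out + dist_out) / 2
prediction = logits.argmax(dim=1)

print(f'averaged logits\t: {logits.size()}')
print(f'prediction\t: {prediction.size()}')
```

With the DeiT class from this article, class_out and dist_out would simply be the two tensors returned by the forward() method.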


Ending

I think that’s everything about the main idea and implementation of DeiT. It is important to note that there are still lots of things I haven’t covered in this article. So, I do recommend you read the paper [2] if you want to get even deeper into the details of this deep learning model.

Thanks for reading, I hope you learned something new today!

By the way, you can access the code used in this article via the link at reference number [7].


References

[1] Alexey Dosovitskiy et al. An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale. Arxiv. https://arxiv.org/abs/2010.11929 [Accessed February 17, 2025].

[2] Hugo Touvron et al. Training Data-Efficient Image Transformers & Distillation Through Attention. Arxiv. https://arxiv.org/abs/2012.12877 [Accessed February 17, 2025].

[3] Image originally created by author.

[4] Muhammad Ardi. Paper Walkthrough: Vision Transformer (ViT). Towards Data Science. https://towardsdatascience.com/paper-walkthrough-vision-transformer-vit-c5dcf76f1a7a/ [Accessed February 17, 2025].

[5] Muhammad Ardi. Paper Walkthrough: Attention Is All You Need. Towards Data Science. https://towardsdatascience.com/paper-walkthrough-attention-is-all-you-need-80399cdc59e1/ [Accessed February 17, 2025].

[6] facebookresearch. GitHub. https://github.com/facebookresearch/deit/blob/main/models.py [Accessed February 17, 2025].

[7] MuhammadArdiPutra. Vision Transformer on a Budget. GitHub. https://github.com/MuhammadArdiPutra/medium_articles/blob/main/Vision%20Transformer%20on%20a%20Budget.ipynb [Accessed February 17, 2025].
