

Channel Attention with a Squeeze-and-Excitation (SE) Block #

Introduction #

Objectives of the lab

  • Understand how an attention mechanism can enhance a CNN.
  • Implement a Squeeze-and-Excitation (SE) block and integrate it into a small CNN.
  • Compare model performance with and without attention.

Convolutional Neural Networks (CNNs) extract feature maps, also called channels. Each channel corresponds to a different type of learned feature (edges, textures, shapes, etc.). The Squeeze-and-Excitation (SE) block introduces a simple attention mechanism: it learns which channels are most important for the current task.

Steps

  1. Train a small baseline CNN on Fashion-MNIST.
  2. Add an SE block after each convolution.
  3. Compare accuracy and observe the attention effect.
  4. Add a self-attention block and compare.

Code with Dataset #

Starter code is available here.

Baseline CNN (without attention) #

Write a network with two convolutional layers (32 then 64 filters, kernel size 3, padding 1, as in the SE version below). Each layer consists of the convolution, max pooling, and then a ReLU.

class CNN_Base(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(  ??? kernel_size=, padding=)
        self.conv2 = nn.Conv2d(  ??? kernel_size=, padding=)
        self.fc = nn.Linear(  ??? , 10)

    def forward(self, x):
        x = F.relu(...
        x = ...
        x = x.view(x.size(0), -1)
        return self.fc(x)
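
For reference, one possible completion is sketched below. It assumes grayscale 28×28 Fashion-MNIST inputs (1 input channel; use 3 if your starter code converts images to RGB) and a 2×2 max pooling in each layer, so the flattened feature size is 64 × 7 × 7. Adapt these values if your starter code differs.

import torch.nn as nn
import torch.nn.functional as F

class CNN_Base(nn.Module):
    def __init__(self):
        super().__init__()
        # 1 input channel (grayscale Fashion-MNIST), 32 then 64 filters
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # two 2x2 max poolings: 28x28 -> 14x14 -> 7x7
        self.fc = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        # each layer: convolution, max pooling, then ReLU
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)   # flatten to (batch, 64*7*7)
        return self.fc(x)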

Classification with SE Block #


SE Block Overview #

  1. Squeeze: Perform global average pooling on each channel, giving a vector of size C.
  2. Excitation: A small MLP learns a weight for each channel (values between 0 and 1).
  3. Recalibration: Multiply each feature map by its corresponding weight.

This is a channel-wise attention mechanism (not spatial).

The SE block is defined below as a module. It takes the feature map and the number of channels as input. Global average pooling (GAP) reduces each channel to a single value (the squeeze step); in PyTorch this is done with AdaptiveAvgPool2d. Two fully connected (Linear) layers then transform these C values into C weights, one per channel (the excitation step). Finally, the output is computed by multiplying each channel by its weight.
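
To see the squeeze step in isolation, the short example below (illustrative, not part of the starter code) shows how nn.AdaptiveAvgPool2d(1) reduces each channel of a feature map to a single value:

import torch
import torch.nn as nn

x = torch.rand(8, 64, 14, 14)   # (batch, channels, height, width)
gap = nn.AdaptiveAvgPool2d(1)   # global average pooling
y = gap(x)                      # shape (8, 64, 1, 1): one value per channel
print(y.view(8, 64).shape)      # torch.Size([8, 64])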

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = ... # GAP
        self.fc = nn.Sequential(                    # Excitation
            nn.Linear(channels, reduction, bias=False),
            nn.ReLU(),
            nn.Linear( reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()       # b = batch size, c = number of channels
        y = ... # GAP
        y = ... # FC
        y = y.view(b, c, 1, 1)
        return x * y        # apply the weights to each channel
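
One possible completion of the block is given below. Note that the stub uses reduction directly as the width of the hidden layer; the original SE paper uses channels // reduction instead, but both choices work for this lab.

import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)     # Squeeze: GAP to (b, c, 1, 1)
        self.fc = nn.Sequential(                    # Excitation: c values -> c weights in [0, 1]
            nn.Linear(channels, reduction, bias=False),
            nn.ReLU(),
            nn.Linear(reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()            # b = batch size, c = number of channels
        y = self.avg_pool(x).view(b, c)  # GAP, flattened to (b, c)
        y = self.fc(y)                   # one weight per channel
        y = y.view(b, c, 1, 1)
        return x * y                     # recalibration: scale each channel

A quick sanity check: SEBlock(64)(torch.rand(8, 64, 7, 7)) should return a tensor of the same shape.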

CNN with SE Attention #

Add an SE attention module after each of the two convolutional layers.

class CNN_SE(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.se1 =  ???
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.se2 = ???
        self.fc = nn.Linear(  ??? , 10)

    def forward(self, x):
        x = ???
        x = ???
        x = ???
        x = ???
        x = x.view(x.size(0), -1)       # batch size: x.size(0)
        return self.fc(x)
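
One possible completion, reusing the SEBlock defined above and the same structure as the baseline (convolution, max pooling, ReLU, then the SE block), is sketched below. As before, the number of input channels (1 for grayscale Fashion-MNIST; 3 only if the starter code converts images to RGB) and the flattened size 64 × 7 × 7 are assumptions to adapt to your setup.

class CNN_SE(nn.Module):
    def __init__(self):
        super().__init__()
        # assumption: 1 input channel (grayscale Fashion-MNIST)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.se1 = SEBlock(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.se2 = SEBlock(64)
        self.fc = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = self.se1(x)                 # reweight the 32 channels
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = self.se2(x)                 # reweight the 64 channels
        x = x.view(x.size(0), -1)       # batch size: x.size(0)
        return self.fc(x)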

Results: with and without Attention #

Expected Results

  • Baseline CNN: ~70 % (15 epochs, without attention)
  • CNN + SEBlock: ~71 % (+1 point, 15 epochs)
  • CNN + self-attention: ~67 % (−3 points, see the next section)

On small datasets, the improvement from SE is smaller than on large datasets such as ImageNet (see the original SE paper). In general, attention mechanisms are more effective on large-scale data and on higher-level latent representations, and they are useful for modeling long-range dependencies that local convolutions cannot see. See the “Gesture transfer” lab, where attention can significantly improve generation.

Self Attention #

PyTorch already provides a multi-head attention module, nn.MultiheadAttention. You can test a third type of network using it: keep the attention over the channels, and do not apply attention across pixels (one possible implementation is sketched after the snippet below).

attn = nn.MultiheadAttention(embed_dim=128, num_heads=8)
x = torch.rand(10, 32, 128)  # (seq_len, batch, embedding_dim)
out, weights = attn(x, x, x)  # self-attention
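
One way to apply this idea to the CNN is sketched below: each channel is treated as one token whose embedding is its flattened spatial map, so attention is computed between channels and never across pixel positions. The spatial size (14×14 after the first pooling) and the number of heads are assumptions matching the CNN above; note that embed_dim must be divisible by num_heads.

import torch
import torch.nn as nn

class ChannelSelfAttention(nn.Module):
    # self-attention where the tokens are the channels
    def __init__(self, height, width, num_heads=1):
        super().__init__()
        # each channel is embedded as its flattened spatial map (height*width values)
        self.attn = nn.MultiheadAttention(embed_dim=height * width, num_heads=num_heads)

    def forward(self, x):
        b, c, h, w = x.size()
        tokens = x.view(b, c, h * w).permute(1, 0, 2)    # (seq_len=c, batch=b, embed_dim=h*w)
        out, _ = self.attn(tokens, tokens, tokens)       # channels attend to each other
        return out.permute(1, 0, 2).reshape(b, c, h, w)  # back to a feature map

# example: after conv1 + pooling, the feature map is (batch, 32, 14, 14)
attn1 = ChannelSelfAttention(height=14, width=14, num_heads=4)
x = torch.rand(8, 32, 14, 14)
print(attn1(x).shape)   # torch.Size([8, 32, 14, 14])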