A Step-by-Step Guide to building a BERT model with PyTorch (Part 2c)

Tahir Rauf
Artificial Intelligence in Plain English
4 min read · Nov 20, 2023


So far in this series, we have accomplished several tasks. In Part 1, we prepared our dataset for BERT training. In Part 2a, we prepared fixed input embeddings for BERT. In Part 2b, we implemented multi-head attention and the feed-forward block. We still need to implement the residual connection and the add & norm layer, and this post does exactly that.

Residual Connection + Add&Norm

The purpose of a residual connection is to let information flow directly from the input of a layer to its output, bypassing the layer's intermediate computations. Because this identity path also carries gradients backward unchanged, it helps mitigate vanishing gradients, which otherwise make deep networks difficult to train.
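To make the gradient argument concrete, here is a minimal sketch. The zero-initialized `nn.Linear` is only a stand-in for a sublayer that passes no gradient (it is not part of this series' code); it shows that the skip path still delivers gradient to the input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in "sublayer" (an assumption for illustration): zero weights block all gradient.
sublayer = nn.Linear(4, 4)
nn.init.zeros_(sublayer.weight)
nn.init.zeros_(sublayer.bias)

x = torch.randn(2, 4, requires_grad=True)

# Without a skip connection, no gradient reaches x through the zeroed layer.
plain_out = sublayer(x)
(plain_grad,) = torch.autograd.grad(plain_out.sum(), x)

# With a skip connection, the identity path still carries gradient back to x.
residual_out = x + sublayer(x)
(residual_grad,) = torch.autograd.grad(residual_out.sum(), x)

print(plain_grad)     # all zeros
print(residual_grad)  # all ones
```

Real sublayers are never exactly zero, but the same effect applies whenever their gradients become very small.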

The ResidualConnection class takes two parameters: d_model, which is the size of the feature dimension of the input, and drop_out_p, which is the dropout probability. Dropout randomly zeroes some of the network's activations during training, which can help to prevent overfitting.
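As a quick illustration of how dropout behaves, the sketch below runs `nn.Dropout` in training and evaluation mode. The probability of 0.5 and the all-ones input are arbitrary choices for the demo, not values from this series.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

drop_out_p = 0.5  # arbitrary demo value
dropout = nn.Dropout(drop_out_p)
x = torch.ones(2, 8)

# Training mode: each entry is zeroed with probability p,
# and the survivors are scaled by 1 / (1 - p) to keep the expected value unchanged.
dropout.train()
train_out = dropout(x)

# Evaluation mode: dropout is a no-op.
dropout.eval()
eval_out = dropout(x)

print(train_out)  # entries are either 0.0 or 2.0
print(eval_out)   # identical to x
```

This is why the model should be switched to `eval()` before inference: otherwise activations keep being dropped at random.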

The forward pass of the ResidualConnection class takes two inputs: x, which is the input tensor, and sublayer, which is a callable that applies a specific operation (e.g., self-attention or the feed-forward network) to x. The sublayer output is layer-normalized, passed through dropout, and then added back to x.

import torch
import torch.nn as nn


class ResidualConnection(nn.Module):
    """
    Implements the residual connection used in the Transformer architecture.
    """
    def __init__(self, d_model, drop_out_p) -> None:
        """
        Initialize the ResidualConnection module.
        Args:
        - d_model (int): Size of the feature dimension that layer normalization is applied over.
        - drop_out_p (float): The dropout probability.
        """
        super().__init__()
        # Define the dropout layer with the given dropout probability
        self.dropout = nn.Dropout(drop_out_p)
        # Layer normalization normalizes the data (mean=0 and variance=1)
        # to stabilize and speed up training.
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, x, sublayer):
        """
        Forward pass for the residual connection.
        Args:
        - x (torch.Tensor): Input tensor.
        - sublayer (callable): A function (or module) that applies a specific operation
          (e.g., self-attention, feed-forward network) to the input.
        Returns:
        - torch.Tensor: Output tensor after applying the sublayer and adding the residual connection.
        """
        # The input x is first passed through the sublayer.
        # The sublayer output then goes through layer normalization.
        # Dropout is applied to the normalized output for regularization.
        # Finally, the original input x is added back (residual connection) to the processed output.
        # Note: Some implementations apply layernorm on x before applying the sublayer.
        return x + self.dropout(self.layernorm(sublayer(x)))
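To see what the layer normalization step does on its own, the following sketch checks the statistics of a normalized tensor. The shapes and the shift/scale applied to the random input are arbitrary demo values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model = 16  # arbitrary demo size
layernorm = nn.LayerNorm(d_model)

# Random input of shape (batch, seq_len, d_model) with a non-zero mean and large spread.
x = torch.randn(2, 5, d_model) * 3.0 + 7.0
normalized = layernorm(x)

# Statistics are computed per token, over the last (feature) dimension.
mean = normalized.mean(dim=-1)
var = normalized.var(dim=-1, unbiased=False)

print(normalized.shape)  # same shape as x
print(mean.abs().max())  # close to 0
print(var.mean())        # close to 1
```

Each token vector is normalized independently of the others in the batch, which is why the layer only needs to know d_model.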

Implement Encoder Block

We have now implemented all the individual components of the encoder block, so we are ready to write a single EncoderBlock.

class EncoderBlock(nn.Module):
    """
    The Encoder is composed of a stack of encoder layers or blocks, which is analogous to stacking convolutional layers in computer vision.
    The main role of the encoder stack is to 'update' the input embeddings to produce representations that encode some contextual information in the sequence.
    For example, the word 'apple' will be updated to be more 'company like' and less 'fruit like' if the words 'keynote' and 'phone' are close to it.
    """
    def __init__(self, num_heads, d_model, dim_feedforward, drop_out_p) -> None:
        super().__init__()
        self.self_attention_block = MultiHeadAttentionBlock(d_model, num_heads, drop_out_p)
        self.feed_forward_block = FeedForwardBlock(d_model, dim_feedforward=dim_feedforward, drop_out_p=drop_out_p)
        # Create two residual connections (one for the attention block and another for the feed-forward block)
        self.residual_connections = nn.ModuleList([ResidualConnection(d_model, drop_out_p) for _ in range(2)])

    def forward(self, x, mask):
        # First residual connection:
        # The lambda wraps self-attention (query, key, and value are all x) into a one-argument callable.
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, mask))

        # Second residual connection:
        # The output from the attention step goes through the feed-forward block, then normalization and dropout,
        # and finally it's combined with the input to this step to form a residual connection.
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

You can test it with the code below:

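# Testing
# Padding mask of shape (batch, 1, seq_len, seq_len): True where the token id is non-zero.
mask = (sample_data['bert_input'] > 0).unsqueeze(1).repeat(1, sample_data['bert_input'].size(1), 1).unsqueeze(1)
# dim_feedforward is a required argument; 2048 matches the default used by the Transformer class in this post.
transformer_block = EncoderBlock(NUM_HEADS, D_MODEL, dim_feedforward=2048, drop_out_p=0.1)
transformer_result = transformer_block(bert_embeding, mask)
transformer_result.size()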
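The mask expression is easier to follow on a tiny input. The sketch below applies the same expression to a hypothetical toy batch (the token ids are made up, and 0 is assumed to be the padding id, as the `> 0` comparison implies).

```python
import torch

# Hypothetical toy batch: 2 sequences of length 4, with 0 assumed to be the padding id.
bert_input = torch.tensor([[5, 9, 3, 0],
                           [7, 2, 0, 0]])

# Same expression as in the test above:
# (B, S) -> (B, 1, S) -> (B, S, S) -> (B, 1, S, S)
mask = (bert_input > 0).unsqueeze(1).repeat(1, bert_input.size(1), 1).unsqueeze(1)

print(mask.shape)  # torch.Size([2, 1, 4, 4])
print(mask[0, 0])  # every row is [True, True, True, False]
```

The extra dimension at index 1 lets the same mask broadcast across all attention heads, and each row marks which key positions a query may attend to.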

BERT Transformer

Now let's implement the Transformer class, which encapsulates all the components of the BERT model, such as the embedding layer, the encoder blocks, and attention masking. This makes it easy to instantiate and use the model in a clean way.

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, dim_feedforward=2048, drop_out_p=0.1, num_encoder_blocks=6):
        super().__init__()
        self.src_embed = BERTEmbedding(vocab_size, d_model, seq_len=SEQ_LEN)
        # nn.ModuleList registers each block so its parameters are trained and moved with the model.
        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(num_heads, d_model, dim_feedforward, drop_out_p) for _ in range(num_encoder_blocks)]
        )

    def encode(self, x, segment_ids):
        # Padding mask of shape (batch, 1, seq_len, seq_len): True where the token id is non-zero.
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        x = self.src_embed(x, segment_ids)
        for encoder_block in self.encoder_blocks:
            x = encoder_block(x, mask)
        return x

    def forward(self, x, segment_ids):
        # Calling the module directly runs the encoder stack.
        return self.encode(x, segment_ids)

You can test it with the code below:

# d_model and num_heads are required arguments of Transformer.
transformer = Transformer(len(tokenizer.vocab), D_MODEL, NUM_HEADS)
bert_result = transformer(sample_data['bert_input'], sample_data['segment_label'])
print(bert_result.size())
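One detail worth checking when stacking encoder blocks is how the stack is stored. The sketch below uses `nn.Linear` layers as stand-ins for encoder blocks (the two class names are hypothetical) to show that submodules kept in a plain Python list are not registered with the parent module, while those in an `nn.ModuleList` are.

```python
import torch.nn as nn


class PlainListStack(nn.Module):
    """Hypothetical stack that keeps its layers in a plain Python list."""
    def __init__(self):
        super().__init__()
        self.blocks = [nn.Linear(4, 4) for _ in range(2)]


class ModuleListStack(nn.Module):
    """Hypothetical stack that keeps its layers in an nn.ModuleList."""
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])


# Count the parameter tensors each parent module can see.
plain_count = len(list(PlainListStack().parameters()))
registered_count = len(list(ModuleListStack().parameters()))

print(plain_count)       # 0: the optimizer would never see these layers
print(registered_count)  # 4: a weight and a bias for each of the two layers
```

Unregistered layers are also skipped by `.to(device)`, `state_dict()`, and `train()`/`eval()`, so the stack would silently go untrained.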

Conclusion:

We've implemented the Transformer class, which creates contextualized embeddings for the input text. In the next post, we'll use all the components to create a complete BERT language model.
