import torch
import torch.nn as nn

from .multi_head_attention import MultiHeadAttention
from .feed_forward import FeedForward


class TransformerBlock(nn.Module):
    """A pre-LayerNorm transformer block: a multi-head attention sub-layer
    followed by a feed-forward sub-layer, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ff = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        # Pre-norm residuals: normalize the input to each sub-layer,
        # then add the sub-layer's output back onto the residual stream.
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x
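

# A minimal, optional smoke test (a sketch, not part of the original module).
# It assumes Config exposes at least `n_embed`; the other fields below
# (n_head, block_size, dropout) are hypothetical and may need adjusting to
# whatever MultiHeadAttention and FeedForward actually read from the config.
# Because of the relative imports above, run this as a module from the
# package root (e.g. python -m <package>.transformer_block).
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(n_embed=64, n_head=4, block_size=16, dropout=0.0)  # hypothetical config
    block = TransformerBlock(cfg)
    x = torch.randn(2, 16, cfg.n_embed)  # (batch, seq_len, n_embed)
    out = block(x)
    print(out.shape)  # residual connections preserve shape: torch.Size([2, 16, 64])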