import torch
import torch.nn as nn

from .multi_head_attention import MultiHeadAttention
from .feed_forward import FeedForward


class TransformerBlock(nn.Module):
    """A pre-norm transformer block: multi-head self-attention followed by a
    feed-forward network, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ff = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        # Pre-norm residuals: normalize, transform, then add back the input.
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x
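

# A minimal usage sketch, hedged: it assumes the config object only needs an
# `n_embed` field at this level; MultiHeadAttention and FeedForward may require
# further fields (e.g. a head count or dropout rate), so the SimpleNamespace
# below is a hypothetical stand-in for the project's real config, not its
# actual API. Because the imports above are relative, run this as part of the
# package (e.g. `python -m <package>.transformer_block`), not as a bare script.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(n_embed=64)  # hypothetical minimal config
    block = TransformerBlock(config)
    x = torch.randn(2, 16, config.n_embed)  # (batch, seq_len, n_embed)
    out = block(x)
    # The residual connections preserve the input shape end to end.
    assert out.shape == x.shape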