import torch
import torch.nn as nn

from .single_attention_head import AttentionHead


class MultiHeadAttention(nn.Module):
    """Runs several attention heads in parallel and projects their
    concatenated outputs back into the embedding dimension."""

    def __init__(self, Config):
        super().__init__()
        self.n_heads = Config.n_heads
        self.head_size = Config.head_size
        # Independent attention heads, each producing a (B, T, head_size) output.
        self.heads = nn.ModuleList([AttentionHead(Config) for _ in range(self.n_heads)])
        # Mixes the concatenated head outputs (n_heads * head_size == n_embed)
        # back into the residual stream.
        self.projection = nn.Linear(Config.n_embed, Config.n_embed)
        self.dropout = nn.Dropout(Config.attn_dropout)

    def forward(self, x):
        # Concatenate each head's output along the channel dimension,
        # then project and apply dropout.
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.projection(x)
        x = self.dropout(x)
        return x
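

# --- Minimal usage sketch (not part of the model) ---------------------------
# This smoke test assumes Config exposes n_embed, n_heads, head_size and
# attn_dropout (as used above), plus whatever AttentionHead itself needs;
# block_size below is a hypothetical attribute for a causal mask. It also
# assumes n_heads * head_size == n_embed so the concatenated heads match the
# projection's input width. Because of the relative import above, run it as a
# module, e.g. `python -m <your_package>.multi_head_attention`.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        n_embed=64,        # width of the residual stream
        n_heads=4,         # number of parallel heads
        head_size=16,      # per-head channel width (64 / 4)
        attn_dropout=0.1,  # dropout applied after the output projection
        block_size=32,     # assumed max sequence length for AttentionHead
    )

    mha = MultiHeadAttention(cfg)
    x = torch.randn(2, 32, cfg.n_embed)  # (batch, time, channels)
    out = mha(x)
    # Multi-head attention preserves the input shape: (2, 32, 64).
    assert out.shape == x.shape
    print(out.shape)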