# File-viewer header captured with the source (not part of the module):
#   Spaces: Sleeping
#   File size: 662 Bytes
#   9fe7c42
import torch
import torch.nn as nn
from .single_attention_head import AttentionHead
class MultiHeadAttention(nn.Module):
    """Run several attention heads on the same input and merge their outputs.

    The constructor builds ``Config.n_heads`` ``AttentionHead`` modules, each
    created from the same ``Config`` object. ``forward`` feeds the input to
    every head, concatenates the head outputs along the last dimension, and
    then applies a linear projection followed by dropout.

    Args:
        Config: Configuration object. This class reads ``n_heads``,
            ``head_size``, ``n_embed`` and ``attn_dropout`` from it and
            passes the whole object on to each ``AttentionHead``.
    """

    def __init__(self, Config):
        super().__init__()
        self.n_heads = Config.n_heads
        # Kept as an attribute; nothing in this class reads it afterwards.
        self.head_size = Config.head_size
        # nn.ModuleList (rather than a plain list) so that the heads are
        # registered as submodules of this module.
        self.heads = nn.ModuleList(
            [AttentionHead(Config) for _ in range(self.n_heads)]
        )
        # NOTE(review): the projection takes n_embed input features, so the
        # concatenated head outputs presumably must be n_embed wide
        # (n_heads * per-head width == n_embed) -- confirm against
        # AttentionHead.
        self.projection = nn.Linear(Config.n_embed, Config.n_embed)
        self.dropout = nn.Dropout(Config.attn_dropout)

    def forward(self, x):
        """Apply every head to ``x``, concatenate, project, and apply dropout.

        Args:
            x: Input passed unchanged to each head.

        Returns:
            The result of ``dropout(projection(cat(head outputs, dim=-1)))``.
        """
        # Every head sees the same input; outputs are joined on the last axis.
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.projection(x)
        x = self.dropout(x)
        return x