import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F


class MultiHeadAttention(Module):
    r"""A class that implements a Multi-head Attention mechanism.

    Multi-head attention allows the model to focus on different positions,
    capturing various aspects of the input.

    Args:
        query_dim (int): The dimensionality of the query.
        key_dim (int): The dimensionality of the key.
        num_units (int): The total number of dimensions of the output.
        num_heads (int): The number of parallel attention layers (multi-heads).

    Inputs: query, key
        - **query**: Tensor of shape [N, T_q, query_dim]
        - **key**: Tensor of shape [N, T_k, key_dim]

    Outputs:
        - An output tensor of shape [N, T_q, num_units]
    """

    def __init__(
        self,
        query_dim: int,
        key_dim: int,
        num_units: int,
        num_heads: int,
    ):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        # Learned projections that map queries, keys, and values into num_units dimensions.
        self.W_query = nn.Linear(
            in_features=query_dim,
            out_features=num_units,
            bias=False,
        )
        self.W_key = nn.Linear(
            in_features=key_dim,
            out_features=num_units,
            bias=False,
        )
        self.W_value = nn.Linear(
            in_features=key_dim,
            out_features=num_units,
            bias=False,
        )

    def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        r"""Performs the forward pass over the input tensors.

        Args:
            query (torch.Tensor): The input tensor containing query vectors.
                It is expected to have the dimensions [N, T_q, query_dim],
                where N is the batch size, T_q is the sequence length of the
                queries, and query_dim is the dimensionality of a single
                query vector.
            key (torch.Tensor): The input tensor containing key vectors.
                It is expected to have the dimensions [N, T_k, key_dim],
                where N is the batch size, T_k is the sequence length of the
                keys, and key_dim is the dimensionality of a single key
                vector.

        Returns:
            torch.Tensor: The output tensor of shape [N, T_q, num_units],
                which represents the result of the multi-head attention
                mechanism applied to the provided queries and keys.
        """
        # Project the inputs into the shared num_units space.
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)  # [N, T_k, num_units]

        # Split each projection into num_heads chunks and stack them along a new head axis.
        split_size = self.num_units // self.num_heads
        querys = torch.stack(
            torch.split(querys, split_size, dim=2),
            dim=0,
        )  # [h, N, T_q, num_units/h]
        keys = torch.stack(
            torch.split(keys, split_size, dim=2),
            dim=0,
        )  # [h, N, T_k, num_units/h]
        values = torch.stack(
            torch.split(values, split_size, dim=2),
            dim=0,
        )  # [h, N, T_k, num_units/h]

        # scores = softmax(Q K^T / key_dim ** 0.5)
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim**0.5)
        scores = F.softmax(scores, dim=3)

        # out = scores * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]

        # Merge the heads back into a single num_units-dimensional output.
        return torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(
            0,
        )  # [N, T_q, num_units]
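

# A minimal usage sketch (not part of the original module): the dimensions and
# names below are illustrative assumptions, chosen only to show the expected
# input and output shapes of MultiHeadAttention.
if __name__ == "__main__":
    batch_size, t_q, t_k = 2, 5, 7
    attention = MultiHeadAttention(query_dim=16, key_dim=32, num_units=64, num_heads=4)
    example_query = torch.randn(batch_size, t_q, 16)  # [N, T_q, query_dim]
    example_key = torch.randn(batch_size, t_k, 32)  # [N, T_k, key_dim]
    output = attention(example_query, example_key)
    print(output.shape)  # torch.Size([2, 5, 64]) -> [N, T_q, num_units]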