import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from typing import Optional, Tuple
import math
import logging

logger = logging.getLogger(__name__)

rwkv_emb_scale = 0.4    # try 0.4 for char-level english. try 1.0 for chinese.
rwkv_layer_decay = 1.0  # decay weights in higher layers. try 0.5 ~ 1.0.


class AttentionConfig:
    def __init__(self, ctx_len=100, **kwargs):
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)

########################################################################################################
# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len=None):
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)  # match inv_freq dtype for the einsum below
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return torch.stack([self.cos_cached, self.sin_cached])


class ContinuousRotaryEmbedding(torch.nn.Module):
    '''Continuous rotary position embedding'''
    def __init__(self, dim, sequence_scale):
        super().__init__()
        base = 10000
        self.sequence_scale = sequence_scale
        self.register_buffer('inv_freq', 1. / (base ** (torch.arange(0, dim, 2))))

    def forward(self, t):
        t = (t + 0.5) * self.sequence_scale
        freqs = torch.einsum('ij,k->ijk', t, self.inv_freq)   # freqs: [B, L, dim//2]
        emb = torch.cat((freqs, freqs), dim=-1).unsqueeze(1)  # emb: [B, 1, L, dim], 1 for broadcast in head_num dim
        return torch.stack([emb.cos(), emb.sin()])


def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), -1)


@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    cos, sin = cos[..., :q.shape[2], :], sin[..., :q.shape[2], :]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


class MHA_rotary(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.collect_attention_map = False
        self.attention_map = None
        assert args.encoder_dim % args.num_heads == 0
        self.num_heads = args.num_heads
        self.head_size = args.encoder_dim // args.num_heads

        if args.timeshift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.query = nn.Linear(args.encoder_dim, args.encoder_dim)
        self.key = nn.Linear(args.encoder_dim, args.encoder_dim)
        self.value = nn.Linear(args.encoder_dim, args.encoder_dim)

        # self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(args.encoder_dim, args.encoder_dim)

    def forward(self, x, RoPE, key_padding_mask=None):
        B, T, C = x.size()

        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)

        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
        # cos, sin = self.rotary_emb(q, seq_len=T)
        cos, sin = RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin)  # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]        # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)                                     # softmax

        x = att @ v                                                      # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)                # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        x = self.output(x)

        if self.collect_attention_map:
            self.attention_map = att

        return x


class MHA_decoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.collect_attention_map = False
        self.attention_map = None
        assert args.decoder_dim % args.num_heads == 0  # head size is derived from decoder_dim below
        self.num_heads = args.num_heads
        self.head_size = args.decoder_dim // args.num_heads

        if args.timeshift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.query = nn.Linear(args.decoder_dim, args.decoder_dim)
        self.key = nn.Linear(args.decoder_dim, args.decoder_dim)
        self.value = nn.Linear(args.decoder_dim, args.decoder_dim)

        # self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(args.decoder_dim, args.decoder_dim)

    def forward(self, x, memory, RoPE, key_padding_mask=None):
        B, T, C = x.size()
        _, L, M = memory.size()
        # print("x size: ", x.size(), 'memory size: ', memory.size())
        # print('B, T, C: ', B, T, C, 'L: ', L)

        # self-attention over the decoder input:
        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
        # cos, sin = self.rotary_emb(q, seq_len=T)
        cos, sin = RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin)  # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]        # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)                                     # softmax

        x = att @ v
        # print("after attention vals: ", x.shape)                       # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)                # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        # x = self.output(x)
        # print("after linear: ", x.shape)

        # cross attention:
        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)       # (B, T, C) -> (B, nh, T, hs)
        k = self.key(memory).view(B, L, self.num_heads, self.head_size).transpose(1, 2)    # (B, L, M) -> (B, nh, L, hs)
        v = self.value(memory).view(B, L, self.num_heads, self.head_size).transpose(1, 2)  # (B, L, M) -> (B, nh, L, hs)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # cross-attention: (B, nh, T, hs) * (B, nh, hs, L) -> (B, nh, T, L)
        # print("att size: ", att.size())
        if key_padding_mask is not None:
            # key_padding_mask was already reshaped to (B, 1, 1, T) above; reusing it here
            # assumes the same padding mask also covers the memory positions (i.e. its length matches L).
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)                                     # softmax

        x = att @ v                                                      # (B, nh, T, L) * (B, nh, L, hs) -> (B, nh, T, hs)
        # print("x decoder size: ", x.size())
        x = x.transpose(1, 2).contiguous().view(B, T, -1)                # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        # print("x decoder size transposed: ", x.size())
        x = self.output(x)

        if self.collect_attention_map:
            self.attention_map = att

        return x


class GeGLU(torch.nn.Module):
    def __init__(self, config, layer_id, time_shift=False):
        super().__init__()
        self.layer_id = layer_id

        if time_shift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        hidden_sz = 3 * config.n_ffn
        self.key = nn.Linear(config.n_embd, hidden_sz)
        self.value = nn.Linear(config.n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)

        k = self.key(x)
        v = self.value(x)
        y = self.weight(F.gelu(k) * v)
        return y
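

########################################################################################################
# Minimal usage sketch (not part of the original module): wires the module's own RotaryEmbedding
# into MHA_rotary with a dummy config and runs one forward pass. The SimpleNamespace config and
# the dimensions below are assumptions for illustration only; the real training code may build
# `args` differently.
########################################################################################################

if __name__ == '__main__':
    from types import SimpleNamespace

    # hypothetical config: only the attributes read by MHA_rotary are set here
    args = SimpleNamespace(encoder_dim=64, num_heads=4, timeshift=False)
    mha = MHA_rotary(args)

    B, T = 2, 16
    x = torch.randn(B, T, args.encoder_dim)

    # stacked (cos, sin) table of shape (2, T, rotary_ndims); unpacked as `cos, sin = RoPE` in forward()
    RoPE = mha.rotary_emb(x, seq_len=T)

    out = mha(x, RoPE)
    print(out.shape)  # expected: torch.Size([2, 16, 64])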