jatingocodeo committed
Commit eccb044 · verified · 1 Parent(s): 6c13380

Update app.py

Files changed (1): app.py +25 -11
app.py CHANGED
@@ -75,30 +75,44 @@ class LlamaAttention(nn.Module):
     def forward(self, hidden_states, attention_mask=None):
         batch_size, seq_length, _ = hidden_states.size()

+        # Project and reshape
         q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
         k = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
         v = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)

+        # Repeat k/v heads if needed
         if self.num_kv_heads < self.num_heads:
             k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
             v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)

-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
+        # Transpose for attention
+        q = q.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
+        k = k.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
+        v = v.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)

-        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+        # Calculate attention scores
+        scale = 1.0 / math.sqrt(self.head_dim)
+        scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # (batch, num_heads, seq_len, seq_len)

+        # Apply attention mask if provided
         if attention_mask is not None:
-            attention_scores = attention_scores + attention_mask
-
-        attention_probs = F.softmax(attention_scores, dim=-1)
-        context = torch.matmul(attention_probs, v)
+            # Ensure mask is broadcastable
+            if attention_mask.dim() == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, seq_len)
+            scores = scores + attention_mask
+
+        # Apply softmax
+        attention_weights = F.softmax(scores, dim=-1)
+
+        # Apply attention to values
+        output = torch.matmul(attention_weights, v)  # (batch, num_heads, seq_len, head_dim)

-        context = context.transpose(1, 2).contiguous()
-        context = context.view(batch_size, seq_length, -1)
+        # Reshape and project back
+        output = output.transpose(1, 2).contiguous()  # (batch, seq_len, num_heads, head_dim)
+        output = output.view(batch_size, seq_length, -1)  # (batch, seq_len, hidden_size)
+        output = self.o_proj(output)

-        return self.o_proj(context)
+        return output

 class LlamaMLP(nn.Module):
     def __init__(self, config):
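
As a sanity check on the new reshaping and 2-D mask handling, the following is a minimal standalone sketch of the same computation outside the module. The tensor sizes, the stand-in projection layers, and the example padding mask are illustrative assumptions; the real values come from the model config in app.py, which this diff does not show.

import math
import torch
import torch.nn.functional as F

# Illustrative sizes only (assumptions, not the real config values).
batch_size, seq_length = 2, 6
num_heads, num_kv_heads, head_dim = 8, 2, 16
hidden_size = num_heads * head_dim

hidden_states = torch.randn(batch_size, seq_length, hidden_size)

# Stand-ins for self.q_proj / self.k_proj / self.v_proj / self.o_proj
# (assumed here to be bias-free Linear layers).
q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = torch.nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = torch.nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
o_proj = torch.nn.Linear(num_heads * head_dim, hidden_size, bias=False)

# Project and reshape, as in the updated forward().
q = q_proj(hidden_states).view(batch_size, seq_length, num_heads, head_dim)
k = k_proj(hidden_states).view(batch_size, seq_length, num_kv_heads, head_dim)
v = v_proj(hidden_states).view(batch_size, seq_length, num_kv_heads, head_dim)

# Grouped-query attention: repeat k/v heads until they match the query heads.
if num_kv_heads < num_heads:
    k = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
    v = v.repeat_interleave(num_heads // num_kv_heads, dim=2)

q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (batch, heads, seq, head_dim)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

# A 2-D additive padding mask (batch, seq_len): 0 keeps a position, -inf masks it.
attention_mask = torch.zeros(batch_size, seq_length)
attention_mask[:, -1] = float("-inf")  # pretend the last token is padding
if attention_mask.dim() == 2:
    # Same broadcast the new code performs: (batch, 1, 1, seq_len)
    attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
scores = scores + attention_mask

weights = F.softmax(scores, dim=-1)
output = torch.matmul(weights, v).transpose(1, 2).contiguous()
output = o_proj(output.view(batch_size, seq_length, -1))

print(output.shape)                  # torch.Size([2, 6, 128])
print(weights[..., -1].abs().max())  # 0: the padded position gets no attention

With the mask expanded to (batch, 1, 1, seq_len), the -inf entries broadcast across every head and every query position, so padded keys receive zero attention weight after the softmax; that is the behaviour the added dim() == 2 check provides for 2-D masks.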