nielsr HF staff committed on
Commit
85920f1
·
1 Parent(s): 806a35d

Add print statements

Browse files
Files changed (1) hide show
  1. visual.py +8 -5
visual.py CHANGED
@@ -31,15 +31,18 @@ class Attention(nn.Module):
31
  self.dense = nn.Linear(config.hidden_size, config.hidden_size)
32
  self.output_dropout = torch.nn.Dropout(config.dropout_prob)
33
 
34
- def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
35
  B, L, _ = x.shape
36
  qkv = self.query_key_value(x)
37
  qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) # 3, B, L, H, D
38
  q, k, v = qkv[0], qkv[1], qkv[2]
39
 
40
- out = xops.memory_efficient_attention(
41
- q, k, v, scale=self.scale,
42
- )
 
 
 
43
  output = self.dense(out.view(B, L, -1))
44
  output = self.output_dropout(output)
45
  return output
@@ -80,7 +83,7 @@ class TransformerLayer(nn.Module):
80
  if print_values:
81
  print("Hidden states before attention:", attention_input[0, :3, :3])
82
 
83
- attention_output = self.attention(attention_input)
84
 
85
  if print_values:
86
  print("Hidden states after attention:", attention_output[0, :3, :3])
 
31
  self.dense = nn.Linear(config.hidden_size, config.hidden_size)
32
  self.output_dropout = torch.nn.Dropout(config.dropout_prob)
33
 
34
+ def forward(self, x: "tensor(B, L, D)", print_values=False) -> "tensor(B, L, D)":
35
  B, L, _ = x.shape
36
  qkv = self.query_key_value(x)
37
  qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) # 3, B, L, H, D
38
  q, k, v = qkv[0], qkv[1], qkv[2]
39
 
40
+ # out = xops.memory_efficient_attention(
41
+ # q, k, v, scale=self.scale,
42
+ # )
43
+
44
+ out = self.attention(q, k, v)
45
+
46
  output = self.dense(out.view(B, L, -1))
47
  output = self.output_dropout(output)
48
  return output
 
83
  if print_values:
84
  print("Hidden states before attention:", attention_input[0, :3, :3])
85
 
86
+ attention_output = self.attention(attention_input, print_values=print_values)
87
 
88
  if print_values:
89
  print("Hidden states after attention:", attention_output[0, :3, :3])