tomaszki commited on
Commit
abdf1f2
·
1 Parent(s): 511e949

fixed float32 conversion

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -29,7 +29,8 @@ def get_attention_weights_and_tokens(text):
29
  tokens = [tokenizer.decode(token) for token in tokenized.input_ids[0]]
30
  tokenized = tokenized.to(device)
31
  output = model(**tokenized, output_attentions=True)
32
- return output.attentions.to(torch.float32), tokens
 
33
 
34
  model = load_model()
35
  tokenizer = load_tokenizer()
 
29
  tokens = [tokenizer.decode(token) for token in tokenized.input_ids[0]]
30
  tokenized = tokenized.to(device)
31
  output = model(**tokenized, output_attentions=True)
32
+ attentions = [attention.to(torch.float32) for attention in output.attentions]
33
+ return attentions, tokens
34
 
35
  model = load_model()
36
  tokenizer = load_tokenizer()