Dekode committed on
Commit
78127df
·
verified ·
1 Parent(s): 5ab3219

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -82,7 +82,7 @@ def initiate_model(config, device):
82
  model = build_transformer(tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size(), config["seq_len"], config['seq_len'], d_model=config['d_model']).to(device)
83
 
84
  model_filename = latest_weights_file_path(config)
85
- state = torch.load(model_filename, map_location=torch.device('cpu'))
86
  model.load_state_dict(state['model_state_dict'])
87
  return model, tokenizer_src, tokenizer_tgt
88
 
@@ -151,7 +151,12 @@ def main():
151
  st.write('Output:', ' '.join(decoder_input_tokens))
152
  st.write('Translated:', output)
153
  st.write('Attention Visualization')
154
- st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, col_tokens, max_sentence_len, model))
 
 
 
 
 
155
  else:
156
  st.write('Enter a sentence to visualize the attention of the model')
157
 
 
82
  model = build_transformer(tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size(), config["seq_len"], config['seq_len'], d_model=config['d_model']).to(device)
83
 
84
  model_filename = latest_weights_file_path(config)
85
+ state = torch.load(model_filename)
86
  model.load_state_dict(state['model_state_dict'])
87
  return model, tokenizer_src, tokenizer_tgt
88
 
 
151
  st.write('Output:', ' '.join(decoder_input_tokens))
152
  st.write('Translated:', output)
153
  st.write('Attention Visualization')
154
+ if attn_type == 'encoder':
155
+ st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, row_tokens, max_sentence_len, model))
156
+ elif attn_type == 'decoder':
157
+ st.write(get_all_attention_maps(attn_type, layers, heads, col_tokens, col_tokens, max_sentence_len, model))
158
+ elif attn_type == 'encoder-decoder':
159
+ st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, col_tokens, max_sentence_len, model))
160
  else:
161
  st.write('Enter a sentence to visualize the attention of the model')
162