Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -82,7 +82,7 @@ def initiate_model(config, device):
|
|
82 |
model = build_transformer(tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size(), config["seq_len"], config['seq_len'], d_model=config['d_model']).to(device)
|
83 |
|
84 |
model_filename = latest_weights_file_path(config)
|
85 |
-
state = torch.load(model_filename
|
86 |
model.load_state_dict(state['model_state_dict'])
|
87 |
return model, tokenizer_src, tokenizer_tgt
|
88 |
|
@@ -151,7 +151,12 @@ def main():
|
|
151 |
st.write('Output:', ' '.join(decoder_input_tokens))
|
152 |
st.write('Translated:', output)
|
153 |
st.write('Attention Visualization')
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
155 |
else:
|
156 |
st.write('Enter a sentence to visualize the attention of the model')
|
157 |
|
|
|
82 |
model = build_transformer(tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size(), config["seq_len"], config['seq_len'], d_model=config['d_model']).to(device)
|
83 |
|
84 |
model_filename = latest_weights_file_path(config)
|
85 |
+
state = torch.load(model_filename)
|
86 |
model.load_state_dict(state['model_state_dict'])
|
87 |
return model, tokenizer_src, tokenizer_tgt
|
88 |
|
|
|
151 |
st.write('Output:', ' '.join(decoder_input_tokens))
|
152 |
st.write('Translated:', output)
|
153 |
st.write('Attention Visualization')
|
154 |
+
if attn_type == 'encoder':
|
155 |
+
st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, row_tokens, max_sentence_len, model))
|
156 |
+
elif attn_type == 'decoder':
|
157 |
+
st.write(get_all_attention_maps(attn_type, layers, heads, col_tokens, col_tokens, max_sentence_len, model))
|
158 |
+
elif attn_type == 'encoder-decoder':
|
159 |
+
st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, col_tokens, max_sentence_len, model))
|
160 |
else:
|
161 |
st.write('Enter a sentence to visualize the attention of the model')
|
162 |
|