Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -132,31 +132,33 @@ def main():
|
|
132 |
input_text = st.text_input('Enter a sentence')
|
133 |
# put two buttons side by side in the sidebar
|
134 |
# translate_button = st.button('Translate', key='translate_button')
|
135 |
-
viz_button = st.button('Visualize Attention', key='viz_button')
|
136 |
attn_type = st.selectbox('Select attention type', ['encoder', 'decoder', 'encoder-decoder'])
|
137 |
-
layers = st.multiselect('Select layers', list(range(
|
138 |
-
heads = st.multiselect('Select heads', list(range(
|
139 |
# allow the user to select the all the layers and heads at once to visualize
|
140 |
if st.checkbox('Select all layers'):
|
141 |
-
layers = list(range(
|
142 |
if st.checkbox('Select all heads'):
|
143 |
-
heads = list(range(
|
144 |
|
145 |
-
if
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
|
|
|
|
160 |
else:
|
161 |
st.write('Enter a sentence to visualize the attention of the model')
|
162 |
|
|
|
132 |
input_text = st.text_input('Enter a sentence')
|
133 |
# put two buttons side by side in the sidebar
|
134 |
# translate_button = st.button('Translate', key='translate_button')
|
135 |
+
# viz_button = st.button('Visualize Attention', key='viz_button')
|
136 |
attn_type = st.selectbox('Select attention type', ['encoder', 'decoder', 'encoder-decoder'])
|
137 |
+
layers = st.multiselect('Select layers', list(range(6)))
|
138 |
+
heads = st.multiselect('Select heads', list(range(8)))
|
139 |
# allow the user to select the all the layers and heads at once to visualize
|
140 |
if st.checkbox('Select all layers'):
|
141 |
+
layers = list(range(6))
|
142 |
if st.checkbox('Select all heads'):
|
143 |
+
heads = list(range(8))
|
144 |
|
145 |
+
if input_text != '':
|
146 |
+
with st.spinner("Translating..."):
|
147 |
+
encoder_input_tokens, decoder_input_tokens, output = process_input(input_text, tokenizer_src, tokenizer_tgt, model, config, device)
|
148 |
+
max_sentence_len = len(encoder_input_tokens)
|
149 |
+
row_tokens = encoder_input_tokens
|
150 |
+
col_tokens = decoder_input_tokens
|
151 |
+
st.write('Input:', ' '.join(encoder_input_tokens))
|
152 |
+
st.write('Output:', ' '.join(decoder_input_tokens))
|
153 |
+
st.write('Translated:', output)
|
154 |
+
st.write('Attention Visualization')
|
155 |
+
with st.spinner("Visualizing Attention..."):
|
156 |
+
if attn_type == 'encoder':
|
157 |
+
st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, row_tokens, max_sentence_len, model))
|
158 |
+
elif attn_type == 'decoder':
|
159 |
+
st.write(get_all_attention_maps(attn_type, layers, heads, col_tokens, col_tokens, max_sentence_len, model))
|
160 |
+
elif attn_type == 'encoder-decoder':
|
161 |
+
st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, col_tokens, max_sentence_len, model))
|
162 |
else:
|
163 |
st.write('Enter a sentence to visualize the attention of the model')
|
164 |
|