Dekode commited on
Commit
a7e4364
·
verified ·
1 Parent(s): 78127df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -20
app.py CHANGED
@@ -132,31 +132,33 @@ def main():
132
  input_text = st.text_input('Enter a sentence')
133
  # put two buttons side by side in the sidebar
134
  # translate_button = st.button('Translate', key='translate_button')
135
- viz_button = st.button('Visualize Attention', key='viz_button')
136
  attn_type = st.selectbox('Select attention type', ['encoder', 'decoder', 'encoder-decoder'])
137
- layers = st.multiselect('Select layers', list(range(3)))
138
- heads = st.multiselect('Select heads', list(range(7)))
139
  # allow the user to select the all the layers and heads at once to visualize
140
  if st.checkbox('Select all layers'):
141
- layers = list(range(3))
142
  if st.checkbox('Select all heads'):
143
- heads = list(range(7))
144
 
145
- if viz_button and input_text != '':
146
- encoder_input_tokens, decoder_input_tokens, output = process_input(input_text, tokenizer_src, tokenizer_tgt, model, config, device)
147
- max_sentence_len = len(encoder_input_tokens)
148
- row_tokens = encoder_input_tokens
149
- col_tokens = decoder_input_tokens
150
- st.write('Input:', ' '.join(encoder_input_tokens))
151
- st.write('Output:', ' '.join(decoder_input_tokens))
152
- st.write('Translated:', output)
153
- st.write('Attention Visualization')
154
- if attn_type == 'encoder':
155
- st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, row_tokens, max_sentence_len, model))
156
- elif attn_type == 'decoder':
157
- st.write(get_all_attention_maps(attn_type, layers, heads, col_tokens, col_tokens, max_sentence_len, model))
158
- elif attn_type == 'encoder-decoder':
159
- st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, col_tokens, max_sentence_len, model))
 
 
160
  else:
161
  st.write('Enter a sentence to visualize the attention of the model')
162
 
 
132
  input_text = st.text_input('Enter a sentence')
133
  # put two buttons side by side in the sidebar
134
  # translate_button = st.button('Translate', key='translate_button')
135
+ # viz_button = st.button('Visualize Attention', key='viz_button')
136
  attn_type = st.selectbox('Select attention type', ['encoder', 'decoder', 'encoder-decoder'])
137
+ layers = st.multiselect('Select layers', list(range(6)))
138
+ heads = st.multiselect('Select heads', list(range(8)))
139
  # allow the user to select the all the layers and heads at once to visualize
140
  if st.checkbox('Select all layers'):
141
+ layers = list(range(6))
142
  if st.checkbox('Select all heads'):
143
+ heads = list(range(8))
144
 
145
+ if input_text != '':
146
+ with st.spinner("Translating..."):
147
+ encoder_input_tokens, decoder_input_tokens, output = process_input(input_text, tokenizer_src, tokenizer_tgt, model, config, device)
148
+ max_sentence_len = len(encoder_input_tokens)
149
+ row_tokens = encoder_input_tokens
150
+ col_tokens = decoder_input_tokens
151
+ st.write('Input:', ' '.join(encoder_input_tokens))
152
+ st.write('Output:', ' '.join(decoder_input_tokens))
153
+ st.write('Translated:', output)
154
+ st.write('Attention Visualization')
155
+ with st.spinner("Visualizing Attention..."):
156
+ if attn_type == 'encoder':
157
+ st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, row_tokens, max_sentence_len, model))
158
+ elif attn_type == 'decoder':
159
+ st.write(get_all_attention_maps(attn_type, layers, heads, col_tokens, col_tokens, max_sentence_len, model))
160
+ elif attn_type == 'encoder-decoder':
161
+ st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, col_tokens, max_sentence_len, model))
162
  else:
163
  st.write('Enter a sentence to visualize the attention of the model')
164