Martijn van Beers committed · Commit 4b1b415 · 1 Parent(s): 9e7d7f8

Try gradio explanations

Files changed (1)
  app.py  +34 -24
app.py CHANGED
@@ -156,7 +156,7 @@ def show_explanation(model, input_ids, attention_mask, index=None, start_layer=0
     output, expl = generate_relevance(
         model, input_ids, attention_mask, index=index, start_layer=start_layer
     )
-    print(output.shape, expl.shape)
+    #print(output.shape, expl.shape)
     # normalize scores
     scaler = PyTMinMaxScalerVectorized()
 
@@ -177,23 +177,34 @@ def show_explanation(model, input_ids, attention_mask, index=None, start_layer=0
         tokens = tokenizer.convert_ids_to_tokens(input_ids[record].flatten())[
             1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
         ]
-        print([(tokens[i], nrm[i].item()) for i in range(len(tokens))])
-        vis_data_records.append(
-            visualization.VisualizationDataRecord(
-                nrm,
-                output[record][classification],
-                classification,
-                classification,
-                index,
-                1,
-                tokens,
-                1,
-            )
-        )
-    return visualize_text(vis_data_records)
-
+        vis_data_records.append(list(zip(tokens, nrm.tolist())))
+        #print([(tokens[i], nrm[i].item()) for i in range(len(tokens))])
+        # vis_data_records.append(
+        #     visualization.VisualizationDataRecord(
+        #         nrm,
+        #         output[record][classification],
+        #         classification,
+        #         classification,
+        #         index,
+        #         1,
+        #         tokens,
+        #         1,
+        #     )
+        # )
+    # return visualize_text(vis_data_records)
+    return vis_data_records
+
+
+def sentence_sentiment(input_text):
+    text_batch = [input_text]
+    encoding = tokenizer(text_batch, return_tensors="pt")
+    input_ids = encoding["input_ids"].to(device)
+    attention_mask = encoding["attention_mask"].to(device)
+    output = model(input_ids=input_ids, attention_mask=attention_mask)[0]
+    index = output.argmax(axis=-1).item()
+    return classification[index]
 
-def run(input_text):
+def sentiment_explanation_hila(input_text):
     text_batch = [input_text]
     encoding = tokenizer(text_batch, return_tensors="pt")
     input_ids = encoding["input_ids"].to(device)
@@ -202,14 +213,12 @@ def run(input_text):
     # true class is positive - 1
     true_class = 1
 
-    html = show_explanation(model, input_ids, attention_mask)
-    return html
-
+    return show_explanation(model, input_ids, attention_mask)
 
-iface = gradio.Interface(
-    fn=run,
+hila = gradio.Interface(
+    fn=sentence_sentiment,
     inputs="text",
-    outputs="html",
+    outputs="label",
     title="RoBERTa Explanability",
     description="Quick demo of a version of [Hila Chefer's](https://github.com/hila-chefer) [Transformer-Explanability](https://github.com/hila-chefer/Transformer-Explainability/) but without the layerwise relevance propagation (as in [Transformer-MM_explainability](https://github.com/hila-chefer/Transformer-MM-Explainability/)) for a RoBERTa model.",
     examples=[
@@ -220,5 +229,6 @@ iface = gradio.Interface(
             "I really didn't like this movie. Some of the actors were good, but overall the movie was boring"
         ],
     ],
+    interpretation=sentiment_explanation_hila
 )
-iface.launch()
+hila.launch()
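
For context on what the new interpretation= argument expects: in the legacy Gradio releases that accept it (the argument was removed in later major versions), a custom interpretation function receives the same arguments as fn and returns one entry per input component; for a text input, that entry is a list of (token, score) pairs, which is the shape show_explanation now returns via vis_data_records (a one-element list for the single text box). Below is a minimal, self-contained sketch of that pattern; the toy_* names are hypothetical stand-ins, not code from this commit.

# Minimal sketch of the Gradio custom-interpretation pattern this commit
# adopts; assumes a legacy Gradio release where Interface still accepts
# an `interpretation=` argument. The toy_* functions are hypothetical
# stand-ins, not part of this repository.
import gradio


def toy_sentiment(text):
    # Stand-in classifier; the real app runs a RoBERTa model here.
    return "positive" if "good" in text.lower() else "negative"


def toy_explanation(text):
    # A custom interpretation function receives the same inputs as `fn`
    # and returns one entry per input component; for a text input that
    # entry is a list of (token, score) pairs, which Gradio renders as
    # a per-token highlight over the input box.
    tokens = text.split()
    scores = [1.0 if t.lower() == "good" else 0.0 for t in tokens]
    return [list(zip(tokens, scores))]


demo = gradio.Interface(
    fn=toy_sentiment,
    inputs="text",
    outputs="label",
    interpretation=toy_explanation,
)

if __name__ == "__main__":
    demo.launch()

Under that contract, sentiment_explanation_hila returning vis_data_records (one list of token/score pairs for the single input, batch size one) matches what the interface expects, which is presumably why the Captum VisualizationDataRecord/HTML path is commented out in this commit.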