Spaces:
Build error
Build error
darkproger
commited on
Commit
•
efb23c9
1
Parent(s):
766dac7
use st.metric for sequence logits
Browse files
app.py
CHANGED
@@ -7,7 +7,8 @@ import streamlit as st
|
|
7 |
import torch
|
8 |
from transformers import BertTokenizerFast
|
9 |
|
10 |
-
from model import BertForTokenAndSequenceJointClassification
|
|
|
11 |
|
12 |
@st.cache(allow_output_mutation=True)
|
13 |
def load_model():
|
@@ -16,22 +17,28 @@ def load_model():
|
|
16 |
"QCRI/PropagandaTechniquesAnalysis-en-BERT",
|
17 |
revision="v0.1.0")
|
18 |
return tokenizer, model
|
19 |
-
|
20 |
-
tokenizer, model = load_model()
|
21 |
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
|
28 |
-
inputs = tokenizer.encode_plus(input, return_tensors="pt")
|
29 |
-
outputs = model(**inputs)
|
30 |
-
sequence_class_index = torch.argmax(outputs.sequence_logits, dim=-1)
|
31 |
-
sequence_class = model.sequence_tags[sequence_class_index[0]]
|
32 |
-
token_class_index = torch.argmax(outputs.token_logits, dim=-1)
|
33 |
-
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
|
34 |
-
tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]
|
35 |
|
36 |
spaces = [not tok.startswith('##') for tok in tokens][1:] + [False]
|
37 |
|
@@ -40,7 +47,7 @@ doc = Doc(Vocab(strings=set(tokens)),
|
|
40 |
spaces=spaces,
|
41 |
ents=[tag if tag == "O" else f"B-{tag}" for tag in tags])
|
42 |
|
43 |
-
labels =
|
44 |
|
45 |
label_select = st.multiselect(
|
46 |
"Tags",
|
|
|
7 |
import torch
|
8 |
from transformers import BertTokenizerFast
|
9 |
|
10 |
+
from model import BertForTokenAndSequenceJointClassification
|
11 |
+
|
12 |
|
13 |
@st.cache(allow_output_mutation=True)
|
14 |
def load_model():
|
|
|
17 |
"QCRI/PropagandaTechniquesAnalysis-en-BERT",
|
18 |
revision="v0.1.0")
|
19 |
return tokenizer, model
|
|
|
|
|
20 |
|
21 |
+
with torch.inference_mode(True):
|
22 |
+
tokenizer, model = load_model()
|
23 |
+
|
24 |
+
st.write("[Propaganda Techniques Analysis BERT](https://huggingface.co/QCRI/PropagandaTechniquesAnalysis-en-BERT) Tagger")
|
25 |
+
|
26 |
+
input = st.text_area('Input', """\
|
27 |
+
In some instances, it can be highly dangerous to use a medicine for the prevention or treatment of COVID-19 that has not been approved by or has not received emergency use authorization from the FDA.
|
28 |
+
""")
|
29 |
+
|
30 |
+
inputs = tokenizer.encode_plus(input, return_tensors="pt")
|
31 |
+
outputs = model(**inputs)
|
32 |
+
sequence_class_index = torch.argmax(outputs.sequence_logits, dim=-1)
|
33 |
+
sequence_class = model.sequence_tags[sequence_class_index[0]]
|
34 |
+
token_class_index = torch.argmax(outputs.token_logits, dim=-1)
|
35 |
+
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
|
36 |
+
tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]
|
37 |
|
38 |
+
columns = st.columns(len(outputs.sequence_logits.flatten()))
|
39 |
+
for col, sequence_tag, logit in zip(columns, model.sequence_tags, outputs.sequence_logits.flatten()):
|
40 |
+
col.metric(sequence_tag, '%.2f' % logit.item())
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
spaces = [not tok.startswith('##') for tok in tokens][1:] + [False]
|
44 |
|
|
|
47 |
spaces=spaces,
|
48 |
ents=[tag if tag == "O" else f"B-{tag}" for tag in tags])
|
49 |
|
50 |
+
labels = model.token_tags[2:]
|
51 |
|
52 |
label_select = st.multiselect(
|
53 |
"Tags",
|