darkproger committed
Commit 26dff99
1 Parent(s): 5d40602

streamlit for QCRI/PropagandaTechniquesAnalysis-en-BERT

Files changed (3)
  1. app.py +29 -0
  2. model.py +128 -0
  3. requirements.txt +3 -0
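Of the three files, app.py is a small Streamlit front end that tokenizes the input text and renders per-token technique tags, model.py defines the joint token- and sequence-level classification head behind the QCRI/PropagandaTechniquesAnalysis-en-BERT checkpoint, and requirements.txt lists the dependencies.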
app.py ADDED
@@ -0,0 +1,29 @@
+ import streamlit as st
+ import torch
+ from transformers import BertTokenizerFast
+
+ from model import BertForTokenAndSequenceJointClassification
+
+ @st.cache(allow_output_mutation=True)
+ def load_model():
+     tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
+     model = BertForTokenAndSequenceJointClassification.from_pretrained(
+         "QCRI/PropagandaTechniquesAnalysis-en-BERT",
+         revision="v0.1.0")
+     return tokenizer, model
+
+ tokenizer, model = load_model()
+
+ # Named "text" rather than "input" so the Python builtin is not shadowed.
+ text = st.text_area('Input', """\
+ In some instances, it can be highly dangerous to use a medicine for the prevention or treatment of COVID-19 that has not been approved by or has not received emergency use authorization from the FDA.
+ """)
+
+ inputs = tokenizer.encode_plus(text, return_tensors="pt")
+ outputs = model(**inputs)
+ # Sentence-level prediction ("Prop" vs. "Non-prop").
+ sequence_class_index = torch.argmax(outputs.sequence_logits, dim=-1)
+ sequence_class = model.sequence_tags[sequence_class_index[0]]
+ # Token-level predictions, dropping the [CLS] and [SEP] special tokens.
+ token_class_index = torch.argmax(outputs.token_logits, dim=-1)
+ tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
+ tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]
+
+ st.table(list(zip(tokens, tags)))
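For quick checks outside the UI, the same inference flow should work in a plain Python session. A minimal sketch, assuming the model.py from this commit is importable and the checkpoint downloads as in app.py (the example sentence is invented):

import torch
from transformers import BertTokenizerFast

from model import BertForTokenAndSequenceJointClassification

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
model = BertForTokenAndSequenceJointClassification.from_pretrained(
    "QCRI/PropagandaTechniquesAnalysis-en-BERT", revision="v0.1.0")
model.eval()

inputs = tokenizer.encode_plus("They are destroying our country!", return_tensors="pt")
with torch.no_grad():  # inference only, no gradients needed
    outputs = model(**inputs)

# Sentence-level tag, then per-token tags without [CLS]/[SEP].
sequence_class = model.sequence_tags[torch.argmax(outputs.sequence_logits, dim=-1)[0]]
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
tags = [model.token_tags[i] for i in torch.argmax(outputs.token_logits, dim=-1)[0].tolist()[1:-1]]
print(sequence_class)
print(list(zip(tokens, tags)))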
model.py ADDED
@@ -0,0 +1,128 @@
+ __author__ = "Yifan Zhang ([email protected])"
+ __copyright__ = "Copyright (C) 2021, Qatar Computing Research Institute, HBKU, Doha"
+
+
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+ import torch
+ from torch import nn
+ from transformers import BertPreTrainedModel, BertModel
+ from transformers.file_utils import ModelOutput
+
+
+ # "<PAD>" and "O" plus the 18 propaganda techniques: 20 token labels in total.
+ TOKEN_TAGS = (
+     "<PAD>", "O",
+     "Name_Calling,Labeling", "Repetition", "Slogans", "Appeal_to_fear-prejudice", "Doubt",
+     "Exaggeration,Minimisation", "Flag-Waving", "Loaded_Language",
+     "Reductio_ad_hitlerum", "Bandwagon",
+     "Causal_Oversimplification", "Obfuscation,Intentional_Vagueness,Confusion",
+     "Appeal_to_Authority", "Black-and-White_Fallacy",
+     "Thought-terminating_Cliches", "Red_Herring", "Straw_Men", "Whataboutism",
+ )
+
+ SEQUENCE_TAGS = ("Non-prop", "Prop")
+
+
+ @dataclass
+ class TokenAndSequenceJointClassifierOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     token_logits: torch.FloatTensor = None
+     sequence_logits: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ class BertForTokenAndSequenceJointClassification(BertPreTrainedModel):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_token_labels = 20
+         self.num_sequence_labels = 2
+
+         self.token_tags = TOKEN_TAGS
+         self.sequence_tags = SEQUENCE_TAGS
+
+         # Weight of the token-level loss in the joint objective.
+         self.alpha = 0.9
+
+         self.bert = BertModel(config)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         # classifier[0]: per-token head, classifier[1]: per-sentence head.
+         self.classifier = nn.ModuleList([
+             nn.Linear(config.hidden_size, self.num_token_labels),
+             nn.Linear(config.hidden_size, self.num_sequence_labels),
+         ])
+         # Maps the two sentence logits to a scalar gate over the token logits.
+         self.masking_gate = nn.Linear(2, 1)
+
+         self.init_weights()
+         self.merge_classifier_1 = nn.Linear(self.num_token_labels + self.num_sequence_labels, self.num_token_labels)
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=True,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+         )
+
+         sequence_output = outputs[0]
+         pooler_output = outputs[1]
+
+         sequence_output = self.dropout(sequence_output)
+         token_logits = self.classifier[0](sequence_output)
+
+         pooler_output = self.dropout(pooler_output)
+         sequence_logits = self.classifier[1](pooler_output)
+
+         # Gate the token logits by the sentence-level prediction: tokens are
+         # down-weighted when the sentence looks non-propagandistic.
+         gate = torch.sigmoid(self.masking_gate(sequence_logits))
+         gates = gate.unsqueeze(1).repeat(1, token_logits.size(1), token_logits.size(2))
+         weighted_token_logits = torch.mul(gates, token_logits)
+
+         logits = [weighted_token_logits, sequence_logits]
+
+         loss = None
+         if labels is not None:
+             criterion = nn.CrossEntropyLoss(ignore_index=0)
+             # pos_weight reflects the Prop/Non-prop imbalance in the training
+             # data; keep it on the logits' device instead of assuming CUDA.
+             binary_criterion = nn.BCEWithLogitsLoss(
+                 pos_weight=torch.tensor([3932 / 14263], device=sequence_logits.device))
+             weighted_token_logits = weighted_token_logits.view(-1, weighted_token_logits.shape[-1])
+             sequence_logits = sequence_logits.view(-1, sequence_logits.shape[-1])
+
+             token_loss = criterion(weighted_token_logits, labels)
+             # A sentence counts as propaganda if any token carries a technique tag;
+             # BCEWithLogitsLoss expects a float target.
+             sequence_label = torch.tensor(
+                 [1.0] if any(label > 0 for label in labels) else [0.0],
+                 device=sequence_logits.device)
+             sequence_loss = binary_criterion(sequence_logits, sequence_label)
+
+             # Joint objective: weighted sum of token- and sentence-level losses.
+             loss = self.alpha * token_loss + (1 - self.alpha) * sequence_loss
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return TokenAndSequenceJointClassifierOutput(
+             loss=loss,
+             token_logits=weighted_token_logits,
+             sequence_logits=sequence_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
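The masking gate is the interesting part of the forward pass: the two sentence-level logits are passed through a learned Linear(2, 1) and a sigmoid, yielding a scalar in (0, 1) that scales every token logit for that sentence. A standalone sketch with hypothetical shapes (batch 1, sequence length 4, 20 token labels):

import torch

token_logits = torch.randn(1, 4, 20)   # per-token scores, as from classifier[0]
sequence_logits = torch.randn(1, 2)    # per-sentence scores, as from classifier[1]
masking_gate = torch.nn.Linear(2, 1)   # same layer shape as in model.py

gate = torch.sigmoid(masking_gate(sequence_logits))  # shape (1, 1), value in (0, 1)
gates = gate.unsqueeze(1).repeat(1, token_logits.size(1), token_logits.size(2))
weighted_token_logits = gates * token_logits         # same shape as token_logits

print(gate.item())                  # one gate per sentence
print(weighted_token_logits.shape)  # torch.Size([1, 4, 20])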
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ streamlit
+ transformers
+ torch
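With these three dependencies installed (for example via pip install -r requirements.txt), the demo is expected to start with the standard Streamlit entry point, streamlit run app.py.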