MatteoFasulo committed
Commit fbdaedd · 1 Parent(s): a4b33d8

Update with main function

Files changed (1)
  app.py  +89 -4
app.py CHANGED
@@ -1,7 +1,92 @@
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
+import torch
+from torch import nn
+from transformers import AutoTokenizer, DebertaV2Config, DebertaV2Model, PreTrainedModel, pipeline
+from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler
+
+model_card = "microsoft/mdeberta-v3-base"
+finetuned_model = "MatteoFasulo/mdeberta-v3-base-subjectivity-sentiment-multilingual"
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+class CustomModel(PreTrainedModel):
+    config_class = DebertaV2Config
+
+    def __init__(self, config, sentiment_dim=3, num_labels=2, *args, **kwargs):
+        super().__init__(config, *args, **kwargs)
+        self.deberta = DebertaV2Model(config)
+        self.pooler = ContextPooler(config)
+        output_dim = self.pooler.output_dim
+        self.dropout = nn.Dropout(0.1)
+
+        self.classifier = nn.Linear(output_dim + sentiment_dim, num_labels)
+
+    def forward(self, input_ids, positive, neutral, negative, attention_mask=None, labels=None):
+        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
+
+        encoder_layer = outputs[0]
+        pooled_output = self.pooler(encoder_layer)
+
+        # Sentiment features as a single tensor
+        sentiment_features = torch.stack((positive, neutral, negative), dim=1)  # Shape: (batch_size, 3)
+
+        # Combine CLS embedding with sentiment features
+        combined_features = torch.cat((pooled_output, sentiment_features), dim=1)
+
+        # Classification head
+        logits = self.classifier(self.dropout(combined_features))
+
+        return {'logits': logits}
+
+def load_tokenizer(model_name: str):
+    return AutoTokenizer.from_pretrained(model_name)
+
+# Load the pre-trained model
+def load_model(model_card: str, finetuned_model: str):
+    tokenizer = AutoTokenizer.from_pretrained(model_card)
+
+    config = DebertaV2Config.from_pretrained(
+        finetuned_model,
+        num_labels=2,
+        id2label={0: 'OBJ', 1: 'SUBJ'},
+        label2id={'OBJ': 0, 'SUBJ': 1},
+        output_attentions=False,
+        output_hidden_states=False
+    )
+
+    model = CustomModel.from_pretrained(finetuned_model, config=config, sentiment_dim=3, num_labels=2)
+
+    return model
+
+def get_sentiment_values(text: str):
+    pipe = pipeline("sentiment-analysis", model="cardiffnlp/twitter-xlm-roberta-base-sentiment", tokenizer="cardiffnlp/twitter-xlm-roberta-base-sentiment", top_k=None)
+    sentiments = pipe(text)[0]
+    # Each entry looks like {'label': 'positive', 'score': ...}; map label -> score
+    return {sentiment['label']: sentiment['score'] for sentiment in sentiments}
+
+def predict_subjectivity(text):
+    sentiment_values = get_sentiment_values(text)
+
+    model = load_model(model_card, finetuned_model)
+    tokenizer = load_tokenizer(model_card)
+
+    inputs = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
+
+    # Sentiment scores as shape-(1,) tensors so the model can stack them into (batch_size, 3)
+    positive = torch.tensor(sentiment_values['positive']).unsqueeze(0)
+    neutral = torch.tensor(sentiment_values['neutral']).unsqueeze(0)
+    negative = torch.tensor(sentiment_values['negative']).unsqueeze(0)
+
+    outputs = model(
+        input_ids=inputs['input_ids'],
+        attention_mask=inputs['attention_mask'],
+        positive=positive,
+        neutral=neutral,
+        negative=negative,
+    )
+    logits = outputs.get('logits')
+
+    predicted_class_idx = logits.argmax().item()
+    predicted_class = model.config.id2label[predicted_class_idx]
+
+    return predicted_class
+
+demo = gr.Interface(
+    fn=predict_subjectivity,
+    inputs=gr.Textbox(
+        label='Input sentence',
+        placeholder='Enter a sentence from a news article',
+        info='Paste a sentence from a news article to determine if it is subjective or objective.'
+    ),
+    outputs=gr.Text(
+        label="Prediction",
+        info="Whether the sentence is subjective or objective."
+    ),
+    title='Subjectivity Detection',
+    description='Detect if a sentence is subjective or objective using a pre-trained model.',
+    theme='huggingface',
+)
+
+demo.launch()
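For intuition about what the new classifier head does, the following is a minimal standalone sketch of the feature concatenation step: the ContextPooler output is joined with the three sentiment scores before the linear layer. The shapes assume mdeberta-v3-base's hidden size of 768; the dummy tensors and example batch size are illustrative only, not taken from the Space.

import torch

# Illustrative shapes only: hidden size 768 for mdeberta-v3-base, plus three
# sentiment scores (positive, neutral, negative) appended per example.
batch_size, hidden_size, sentiment_dim, num_labels = 2, 768, 3, 2

pooled_output = torch.randn(batch_size, hidden_size)        # stand-in for the ContextPooler output
sentiment_features = torch.rand(batch_size, sentiment_dim)  # stand-in for the pipeline scores

combined = torch.cat((pooled_output, sentiment_features), dim=1)
classifier = torch.nn.Linear(hidden_size + sentiment_dim, num_labels)
logits = classifier(combined)

print(combined.shape)  # torch.Size([2, 771])
print(logits.shape)    # torch.Size([2, 2])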