MatteoFasulo commited on
Commit
e91e5d5
·
1 Parent(s): 7e9f859

Enhance subjectivity prediction with detailed output and update Gradio interface

Browse files
Files changed (1) hide show
  1. app.py +36 -13
app.py CHANGED
@@ -7,7 +7,8 @@ import torch.nn as nn
7
 
8
  # Define the model and tokenizer
9
  model_card = "microsoft/mdeberta-v3-base"
10
- finetuned_model = "MatteoFasulo/mdeberta-v3-base-subjectivity-sentiment-multilingual"
 
11
 
12
  # Custom model class for combining sentiment analysis with subjectivity detection
13
  class CustomModel(PreTrainedModel):
@@ -22,7 +23,7 @@ class CustomModel(PreTrainedModel):
22
 
23
  self.classifier = nn.Linear(output_dim + sentiment_dim, num_labels)
24
 
25
- def forward(self, input_ids, positive, neutral, negative, attention_mask=None, labels=None):
26
  outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
27
 
28
  encoder_layer = outputs[0]
@@ -66,24 +67,48 @@ def get_sentiment_values(text: str):
66
  sentiments = pipe(text)[0]
67
  return {k:v for k,v in [(list(sentiment.values())[0], list(sentiment.values())[1]) for sentiment in sentiments]}
68
 
69
- # Predict the subjectivity of a sentence
70
  def predict_subjectivity(text):
71
  sentiment_values = get_sentiment_values(text)
72
 
73
  model = load_model(model_card, finetuned_model)
74
  tokenizer = load_tokenizer(model_card)
75
 
 
 
 
 
76
  inputs = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
 
 
 
77
 
78
  outputs = model(**inputs)
79
  logits = outputs.get('logits')
80
 
81
- predicted_class_idx = logits.argmax().item()
 
 
 
 
 
82
  predicted_class = model.config.id2label[predicted_class_idx]
83
 
84
- return predicted_class
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- # Create a Gradio interface
87
  demo = gr.Interface(
88
  fn=predict_subjectivity,
89
  inputs=gr.Textbox(
@@ -91,14 +116,12 @@ demo = gr.Interface(
91
  placeholder='Enter a sentence from a news article',
92
  info='Paste a sentence from a news article to determine if it is subjective or objective.'
93
  ),
94
- outputs=gr.Text(
95
- label="Prediction",
96
- info="Whether the sentence is subjective or objective."
97
  ),
98
  title='Subjectivity Detection',
99
- description='Detect if a sentence is subjective or objective using a pre-trained model.',
100
- theme='huggingface',
101
  )
102
 
103
- # Launch the interface
104
- demo.launch(share=True)
 
7
 
8
  # Define the model and tokenizer
9
  model_card = "microsoft/mdeberta-v3-base"
10
+ finetuned_model = "MatteoFasulo/mdeberta-v3-base-subjectivity-sentiment-multilingual-no-arabic"
11
+ THRESHOLD = 0.65
12
 
13
  # Custom model class for combining sentiment analysis with subjectivity detection
14
  class CustomModel(PreTrainedModel):
 
23
 
24
  self.classifier = nn.Linear(output_dim + sentiment_dim, num_labels)
25
 
26
+ def forward(self, input_ids, positive, neutral, negative, token_type_ids=None, attention_mask=None, labels=None):
27
  outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
28
 
29
  encoder_layer = outputs[0]
 
67
  sentiments = pipe(text)[0]
68
  return {k:v for k,v in [(list(sentiment.values())[0], list(sentiment.values())[1]) for sentiment in sentiments]}
69
 
70
+ # Modify the predict_subjectivity function to return additional information
71
  def predict_subjectivity(text):
72
  sentiment_values = get_sentiment_values(text)
73
 
74
  model = load_model(model_card, finetuned_model)
75
  tokenizer = load_tokenizer(model_card)
76
 
77
+ positive = sentiment_values['positive']
78
+ neutral = sentiment_values['neutral']
79
+ negative = sentiment_values['negative']
80
+
81
  inputs = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
82
+ inputs['positive'] = torch.tensor(positive).unsqueeze(0)
83
+ inputs['neutral'] = torch.tensor(neutral).unsqueeze(0)
84
+ inputs['negative'] = torch.tensor(negative).unsqueeze(0)
85
 
86
  outputs = model(**inputs)
87
  logits = outputs.get('logits')
88
 
89
+ # Calculate probabilities using softmax
90
+ probabilities = torch.nn.functional.softmax(logits, dim=1)
91
+ obj_prob, subj_prob = probabilities[0].tolist()
92
+
93
+ # Predict the class given the decision threshold
94
+ predicted_class_idx = 1 if subj_prob >= THRESHOLD else 0
95
  predicted_class = model.config.id2label[predicted_class_idx]
96
 
97
+ # Format the output
98
+ result = f"""Prediction: {predicted_class}
99
+
100
+ Class Probabilities:
101
+ - Objective: {obj_prob:.2%}
102
+ - Subjective: {subj_prob:.2%}
103
+
104
+ Sentiment Scores:
105
+ - Positive: {positive:.2%}
106
+ - Neutral: {neutral:.2%}
107
+ - Negative: {negative:.2%}"""
108
+
109
+ return result
110
 
111
+ # Update the Gradio interface
112
  demo = gr.Interface(
113
  fn=predict_subjectivity,
114
  inputs=gr.Textbox(
 
116
  placeholder='Enter a sentence from a news article',
117
  info='Paste a sentence from a news article to determine if it is subjective or objective.'
118
  ),
119
+ outputs=gr.Textbox(
120
+ label="Results",
121
+ info="Detailed analysis including subjectivity prediction, class probabilities, and sentiment scores."
122
  ),
123
  title='Subjectivity Detection',
124
+ description='Detect if a sentence is subjective or objective using a pre-trained model.'
 
125
  )
126
 
127
+ demo.launch()