"""Gradio demo: subjectivity detection augmented with sentiment scores.

A fine-tuned mDeBERTa-v3 encoder is combined with three sentiment
probabilities (positive / neutral / negative) produced by a separate
sentiment-analysis pipeline. The concatenated features feed a linear
classifier that labels a sentence as objective (OBJ) or subjective (SUBJ).
"""
import functools

import gradio as gr
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,
    DebertaV2Config,
    DebertaV2Model,
    PreTrainedModel,
    pipeline,
)
from transformers.models.deberta.modeling_deberta import ContextPooler

# Define the model and tokenizer
model_card = "microsoft/mdeberta-v3-base"
finetuned_model = "MatteoFasulo/mdeberta-v3-base-subjectivity-sentiment-multilingual-no-arabic"

# Minimum subjective-class probability required to predict SUBJ.
THRESHOLD = 0.65

# Checkpoint used to compute the auxiliary sentiment features.
_SENTIMENT_MODEL = "cardiffnlp/twitter-xlm-roberta-base-sentiment"


class CustomModel(PreTrainedModel):
    """DeBERTa-v2 encoder whose pooled output is joined with sentiment scores.

    The classifier input is the pooled encoder representation concatenated
    with ``sentiment_dim`` extra features, projected to ``num_labels`` logits.
    """

    config_class = DebertaV2Config

    def __init__(self, config, sentiment_dim=3, num_labels=2, *args, **kwargs):
        """Build the encoder, pooler, dropout and classification head.

        Args:
            config: ``DebertaV2Config`` describing the encoder.
            sentiment_dim: Number of sentiment features appended to the
                pooled output.
            num_labels: Number of output classes.
        """
        super().__init__(config, *args, **kwargs)
        self.deberta = DebertaV2Model(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(output_dim + sentiment_dim, num_labels)

    def forward(
        self,
        input_ids,
        positive,
        neutral,
        negative,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
    ):
        """Compute classification logits.

        Args:
            input_ids: Token ids for the encoder.
            positive, neutral, negative: One sentiment score per batch item;
                stacked along ``dim=1`` into a ``(batch_size, 3)`` tensor.
            token_type_ids: Accepted so tokenizer output can be splatted into
                the call; not forwarded to the encoder.
            attention_mask: Attention mask for the encoder.
            labels: Accepted for API compatibility; no loss is computed.

        Returns:
            A dict with a single key ``'logits'``.
        """
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)

        # Sentiment features as a single tensor, shape: (batch_size, 3)
        sentiment_features = torch.stack((positive, neutral, negative), dim=1)

        # Combine the pooled embedding with the sentiment features
        combined_features = torch.cat((pooled_output, sentiment_features), dim=1)

        # Classification head
        logits = self.classifier(self.dropout(combined_features))
        return {'logits': logits}


@functools.lru_cache(maxsize=None)
def load_tokenizer(model_name: str):
    """Return the tokenizer for ``model_name``.

    Cached so the tokenizer is loaded once per process rather than on
    every prediction request.
    """
    return AutoTokenizer.from_pretrained(model_name)


@functools.lru_cache(maxsize=None)
def load_model(model_card: str, finetuned_model: str):
    """Return the fine-tuned ``CustomModel`` in evaluation mode.

    Args:
        model_card: Base model identifier. Kept for backward compatibility;
            the weights and config both come from ``finetuned_model``.
        finetuned_model: Identifier of the fine-tuned checkpoint.

    Returns:
        The loaded model. The result is cached, so repeated calls with the
        same arguments return the same instance instead of reloading weights.
    """
    config = DebertaV2Config.from_pretrained(
        finetuned_model,
        num_labels=2,
        id2label={0: 'OBJ', 1: 'SUBJ'},
        label2id={'OBJ': 0, 'SUBJ': 1},
        output_attentions=False,
        output_hidden_states=False,
    )
    # ``from_pretrained`` is a classmethod: call it on the class and hand it
    # the config explicitly. Calling it on a freshly built instance (as the
    # code did before) threw that instance and its config away.
    model = CustomModel.from_pretrained(finetuned_model, config=config)
    # Make inference mode explicit so dropout is disabled.
    model.eval()
    return model


@functools.lru_cache(maxsize=None)
def _load_sentiment_pipeline():
    """Build (once) the pipeline that returns scores for every sentiment label."""
    return pipeline(
        "sentiment-analysis",
        model=_SENTIMENT_MODEL,
        tokenizer=_SENTIMENT_MODEL,
        top_k=None,
    )


def get_sentiment_values(text: str):
    """Return a ``{label: score}`` mapping of sentiment scores for ``text``.

    The caller reads the ``'positive'``, ``'neutral'`` and ``'negative'``
    keys of the returned dict.
    """
    pipe = _load_sentiment_pipeline()
    # Truncate so inputs longer than the sentiment model's context window do
    # not raise inside the pipeline.
    sentiments = pipe(text, truncation=True, max_length=512)[0]
    return {entry['label']: entry['score'] for entry in sentiments}


def predict_subjectivity(text):
    """Classify ``text`` as objective or subjective and format a report.

    Args:
        text: The sentence to analyse.

    Returns:
        A multi-line string with the predicted class, both class
        probabilities and the three sentiment scores.
    """
    sentiment_values = get_sentiment_values(text)

    model = load_model(model_card, finetuned_model)
    tokenizer = load_tokenizer(model_card)

    positive = sentiment_values['positive']
    neutral = sentiment_values['neutral']
    negative = sentiment_values['negative']

    inputs = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
    # The model expects one sentiment score per batch item; batch size is 1.
    inputs['positive'] = torch.tensor(positive).unsqueeze(0)
    inputs['neutral'] = torch.tensor(neutral).unsqueeze(0)
    inputs['negative'] = torch.tensor(negative).unsqueeze(0)

    # Gradients are never needed when serving predictions.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.get('logits')

    # Calculate probabilities using softmax
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    obj_prob, subj_prob = probabilities[0].tolist()

    # Predict the class given the decision threshold
    predicted_class_idx = 1 if subj_prob >= THRESHOLD else 0
    predicted_class = model.config.id2label[predicted_class_idx]

    # Format the output
    result = "\n".join([
        f"Prediction: {predicted_class}",
        "",
        "Class Probabilities:",
        f"- Objective: {obj_prob:.2%}",
        f"- Subjective: {subj_prob:.2%}",
        "",
        "Sentiment Scores:",
        f"- Positive: {positive:.2%}",
        f"- Neutral: {neutral:.2%}",
        f"- Negative: {negative:.2%}",
    ])

    return result


# Gradio interface
demo = gr.Interface(
    fn=predict_subjectivity,
    inputs=gr.Textbox(
        label='Input sentence',
        placeholder='Enter a sentence from a news article',
        info='Paste a sentence from a news article to determine if it is subjective or objective.',
    ),
    outputs=gr.Textbox(
        label="Results",
        info="Detailed analysis including subjectivity prediction, class probabilities, and sentiment scores.",
    ),
    title='Subjectivity Detection',
    description='Detect if a sentence is subjective or objective using a pre-trained model.',
)

demo.launch()