File size: 3,872 Bytes
a4b33d8
b3b327d
138ec98
 
b3b327d
 
a4b33d8
b3b327d
fbdaedd
 
a4b33d8
b3b327d
fbdaedd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3b327d
fbdaedd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3b327d
fbdaedd
 
 
 
 
b3b327d
fbdaedd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3b327d
fbdaedd
 
 
 
 
 
 
 
 
 
 
 
b3b327d
fbdaedd
 
 
b3b327d
fbdaedd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from functools import lru_cache

import gradio as gr
import torch
import torch.nn as nn
from transformers import DebertaV2Model, DebertaV2Config, AutoTokenizer, PreTrainedModel
from transformers import pipeline
from transformers.models.deberta.modeling_deberta import ContextPooler

# Hugging Face Hub identifiers used by the app:
#   model_card      -> base checkpoint; the tokenizer is loaded from here.
#   finetuned_model -> fine-tuned subjectivity checkpoint; config and weights come from here.
model_card, finetuned_model = (
    "microsoft/mdeberta-v3-base",
    "MatteoFasulo/mdeberta-v3-base-subjectivity-sentiment-multilingual",
)

# Custom model class for combining sentiment analysis with subjectivity detection
class CustomModel(PreTrainedModel):
    """DeBERTa-v2 encoder plus sentiment features feeding a linear classifier.

    The encoder's first-token representation is pooled by ``ContextPooler`` and
    concatenated with three sentiment scores (positive, neutral, negative). The
    result passes through dropout (p=0.1) and a linear layer producing
    ``num_labels`` logits.
    """

    config_class = DebertaV2Config

    def __init__(self, config, sentiment_dim=3, num_labels=2, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # Sub-module attribute names double as checkpoint state-dict prefixes.
        self.deberta = DebertaV2Model(config)
        self.pooler = ContextPooler(config)
        self.dropout = nn.Dropout(0.1)
        # The classifier sees the pooled text vector with the sentiment scores appended.
        classifier_in_features = self.pooler.output_dim + sentiment_dim
        self.classifier = nn.Linear(classifier_in_features, num_labels)

    def forward(self, input_ids, positive, neutral, negative, attention_mask=None, labels=None):
        """Return ``{'logits': tensor}`` for a batch.

        ``positive``, ``neutral`` and ``negative`` are stacked along dim 1, so each
        is presumably a 1-D tensor of length batch_size (giving a (batch, 3)
        block). ``labels`` is accepted for API compatibility but is not used:
        no loss is computed here.
        """
        hidden_states = self.deberta(input_ids=input_ids, attention_mask=attention_mask)[0]
        text_vector = self.pooler(hidden_states)

        sentiment_block = torch.stack([positive, neutral, negative], dim=1)
        features = torch.cat([text_vector, sentiment_block], dim=1)

        return {'logits': self.classifier(self.dropout(features))}

# Load the pre-trained tokenizer
@lru_cache(maxsize=4)
def load_tokenizer(model_name: str):
    """Return the tokenizer for *model_name*, loading it at most once per name.

    ``predict_subjectivity`` calls this on every request, so the result is
    memoised to avoid re-reading tokenizer files for each prediction.

    Args:
        model_name: Hub id or local path accepted by ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer (the same object on repeated calls with the same name).
    """
    return AutoTokenizer.from_pretrained(model_name)

# Load the pre-trained model
@lru_cache(maxsize=2)
def load_model(model_card: str, finetuned_model: str):
    """Load the fine-tuned ``CustomModel`` with OBJ/SUBJ label names.

    The result is memoised per argument pair, so per-request callers reuse one
    loaded model instead of reloading the weights every time.

    Args:
        model_card: Base checkpoint id. Kept for backward compatibility; the
            tokenizer is loaded separately (``load_tokenizer``), so it is not used here.
        finetuned_model: Hub id or path of the fine-tuned checkpoint that
            supplies both the configuration and the weights.

    Returns:
        A ``CustomModel`` whose ``config.id2label`` is ``{0: 'OBJ', 1: 'SUBJ'}``.
    """
    config = DebertaV2Config.from_pretrained(
        finetuned_model,
        num_labels=2,
        id2label={0: 'OBJ', 1: 'SUBJ'},
        label2id={'OBJ': 0, 'SUBJ': 1},
        output_attentions=False,
        output_hidden_states=False,
    )

    # `from_pretrained` is a classmethod. The previous
    # `CustomModel(config=...).from_pretrained(finetuned_model)` built a throwaway
    # randomly-initialised model, then re-read the config from the hub, dropping the
    # label overrides above. Passing `config=` applies them, and the remaining kwargs
    # are forwarded to `CustomModel.__init__`.
    return CustomModel.from_pretrained(
        finetuned_model,
        config=config,
        sentiment_dim=3,
        num_labels=2,
    )

# Get sentiment values using a pre-trained sentiment analysis model
@lru_cache(maxsize=1)
def _sentiment_pipeline():
    """Build the multilingual sentiment pipeline once and reuse it across calls."""
    return pipeline(
        "sentiment-analysis",
        model="cardiffnlp/twitter-xlm-roberta-base-sentiment",
        tokenizer="cardiffnlp/twitter-xlm-roberta-base-sentiment",
        top_k=None,
    )


def _scores_by_label(sentiments):
    """Turn a text-classification pipeline result into ``{label: score}``.

    Accepts either a flat list of ``{'label': ..., 'score': ...}`` dicts or that
    list wrapped in one outer list. The output shape for a single string differs
    across transformers versions. An empty result yields an empty dict.
    """
    if sentiments and isinstance(sentiments[0], list):
        sentiments = sentiments[0]
    # Read fields by name rather than by dict insertion order.
    return {item['label']: item['score'] for item in sentiments}


def get_sentiment_values(text: str):
    """Return the sentiment scores of *text* keyed by label.

    The labels are whatever the sentiment model emits. ``predict_subjectivity``
    expects 'positive', 'neutral' and 'negative'.

    Args:
        text: The sentence to score.

    Returns:
        A dict mapping each sentiment label to its probability score.
    """
    return _scores_by_label(_sentiment_pipeline()(text))

# Predict the subjectivity of a sentence
def predict_subjectivity(text):
    """Classify *text* as subjective or objective.

    The sentence is scored by the sentiment pipeline and tokenized for the
    DeBERTa encoder. Both are fed to the model, and the label of the
    highest-scoring logit is returned.

    Args:
        text: A single sentence (e.g. from a news article).

    Returns:
        The predicted label string looked up in ``model.config.id2label``.
    """
    sentiment_values = get_sentiment_values(text)

    model = load_model(model_card, finetuned_model)
    tokenizer = load_tokenizer(model_card)

    inputs = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')

    # forward() requires the three sentiment scores as (batch,) tensors; the previous
    # code computed them but never passed them. Assumes the sentiment model's labels
    # are exactly 'positive'/'neutral'/'negative' — a KeyError surfaces a mismatch.
    sentiment_tensors = {
        label: torch.tensor([sentiment_values[label]], dtype=torch.float32)
        for label in ('positive', 'neutral', 'negative')
    }

    # Pass only the tokenizer outputs forward() accepts (input_ids, attention_mask),
    # rather than **inputs, which may carry extra keys it does not take.
    with torch.no_grad():
        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs.get('attention_mask'),
            **sentiment_tensors,
        )

    logits = outputs['logits']
    # Single-sentence batch: take the argmax over classes for row 0.
    predicted_class_idx = logits.argmax(dim=-1)[0].item()
    return model.config.id2label[predicted_class_idx]

# Create a Gradio interface
demo = gr.Interface(
    fn=predict_subjectivity,
    inputs=gr.Textbox(
        label='Input sentence',
        placeholder='Enter a sentence from a news article',
        info='Paste a sentence from a news article to determine if it is subjective or objective.',
    ),
    outputs=gr.Text(
        label="Prediction",
        info="Whether the sentence is subjective or objective.",
    ),
    title='Subjectivity Detection',
    description='Detect if a sentence is subjective or objective using a pre-trained model.',
    theme='huggingface',
)

# Launch the web server only when this file is executed as a script, so that
# importing the module (e.g. from tests or another app) has no side effects.
if __name__ == "__main__":
    demo.launch()