File size: 1,887 Bytes
5e95f70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af8cb67
9704190
5e95f70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

# Labels of the multi-class sentiment model. The order must match the output
# index order of the classification head: classify() maps the argmax index of
# the model's logits straight into this list.
class_names = ['Negative', 'Neutral', 'Positive']

# Build the Sentiment Classifier class 
class SentimentClassifier(nn.Module):
    # Constructor class 
    def __init__(self, n_classes):
        super(SentimentClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('lothritz/LuxemBERT')
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)
    
    # Forward propagaion class
    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
          input_ids=input_ids,
          attention_mask=attention_mask,
            return_dict=False
        )
        #  Add a dropout layer 
        output = self.drop(pooled_output)
        return self.out(output)
# Build the BERT sentiment classifier (one output per entry in class_names)
# and load the fine-tuned weights from the local checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = SentimentClassifier(len(class_names))
model.load_state_dict(torch.load('./pytorch_model.bin', map_location=torch.device('cpu')))
# Switch to evaluation mode so dropout is disabled and predictions are
# deterministic; without this every request would be randomly perturbed.
model.eval()
tokenizer = BertTokenizer.from_pretrained('./')

def encode(text, max_length=50):
    """Tokenize ``text`` into fixed-length PyTorch tensors for the model.

    Args:
        text: Input string to classify.
        max_length: Length, including special tokens, that every sequence is
            padded or truncated to. Defaults to 50.

    Returns:
        The tokenizer's encoding holding ``input_ids`` and ``attention_mask``
        PyTorch tensors for a batch of one; token type ids are omitted.
    """
    # Calling the tokenizer directly replaces the deprecated ``encode_plus``.
    # ``padding='max_length'`` replaces the deprecated ``pad_to_max_length``,
    # and ``truncation=True`` makes explicit the truncation that was
    # previously applied implicitly because ``max_length`` was given.
    return tokenizer(
        text,
        max_length=max_length,
        add_special_tokens=True,
        padding='max_length',
        truncation=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors='pt',
    )

def classify(text):
    """Predict the sentiment label of ``text``.

    Args:
        text: Input string (as supplied by the Gradio text box).

    Returns:
        The entry of ``class_names`` with the highest model logit.
    """
    encoded_comment = encode(text)
    input_ids = encoded_comment['input_ids']
    attention_mask = encoded_comment['attention_mask']

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    # One row of logits per input; take the best class of the single row and
    # convert the index tensor to a plain Python int before indexing.
    prediction = torch.argmax(logits, dim=1)[0].item()
    return class_names[prediction]

# Gradio UI: a single text box in, the predicted sentiment label out.
demo = gr.Interface(
    fn=classify,
    inputs="text",
    outputs="text",
    title="Sentiment Analyser",
    description="Text classifier for Luxembourgish",
)


# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()