File size: 799 Bytes
b72368a
 
 
10210a4
 
d21a6a7
 
b72368a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from transformers import AutoModelForSequenceClassification,AutoTokenizer

# Checkpoint identifier handed to ``from_pretrained`` for both the model and
# the tokenizer, so the two always come from the same source.
chkpt = 'distilbert-base-uncased-finetuned-sst-2-english'

# Loaded once at import time; classify_sentiment uses these as its defaults.
model = AutoModelForSequenceClassification.from_pretrained(chkpt)
tokenizer = AutoTokenizer.from_pretrained(chkpt)
# Alternative kept from the original: load the tokenizer from a local
# 'sentiment_classifier/' directory instead of the checkpoint id.

def classify_sentiment(texts, model=model, tokenizer=tokenizer):
    """Label each input text as ``'Positive'`` or ``'Negative'``.

    Args:
        texts: Either one string containing several texts separated by
            commas, or an iterable of already-separated strings.
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns a mapping with a ``'logits'`` entry.
            Defaults to the module-level model.
        tokenizer: Callable that turns a list of strings into padded,
            truncated PyTorch tensors. Defaults to the module-level
            tokenizer.

    Returns:
        A list with one label per text, in input order: ``'Positive'``
        when the highest-scoring class index is 1, ``'Negative'``
        otherwise. An empty input sequence yields an empty list.
    """
    # Only a plain string is split on commas; anything else is treated as
    # an already-separated collection of texts. This replaces a bare
    # ``except: pass`` that hid every error raised while splitting.
    if isinstance(texts, str):
        texts = texts.split(',')
    else:
        texts = list(texts)

    # Nothing to classify: skip tokenization and the forward pass.
    if not texts:
        return []

    encoded = tokenizer(texts, padding=True, truncation=True,
                        return_tensors="pt")

    # Inference only, so gradient tracking is unnecessary overhead.
    with torch.no_grad():
        logits = model(**encoded)['logits']

    # argmax over raw logits selects the same class as argmax over their
    # softmax, so the normalisation step is not needed.
    predicted = torch.argmax(logits, dim=1).tolist()
    return ['Positive' if label == 1 else 'Negative' for label in predicted]