# Spaces status at capture time: Runtime error
import gradio as gr
from transformers import pipeline

# Number of most recent messages whose sentiment is scored; the Gradio UI
# below is also built with this many text inputs.
NUM_MESSAGES = 8

# Sentiment classifier shared by every request. Building it at module level
# means the model is loaded once per process rather than on every call.
pipe = pipeline(
    "text-classification",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
)
def sentiment_analysis(*messages):
    """
    Return the average sentiment of up to NUM_MESSAGES messages on a 0-100 scale.

    Each non-blank message is classified by ``pipe``:
    - A ``POSITIVE`` prediction contributes ``+score``.
    - Any other label contributes ``-score``.

    So every message maps to a signed value in [-1, 1]. Those values are
    averaged over the messages actually scored. The mean is then rescaled from
    [-1, 1] to [0, 100], where 0 is fully negative, 50 is balanced and 100 is
    fully positive.

    Args:
        *messages: Message strings, oldest first. Blank entries (None, "" or
            whitespace-only, e.g. unused Gradio text boxes) are ignored. Only
            the last NUM_MESSAGES non-blank messages are scored.

    Returns:
        The average sentiment rounded to 2 decimals, or 0 when there is no
        non-blank message to score.
    """
    # Empty text boxes would otherwise be classified too and pull the
    # average toward whatever label the model assigns to empty input.
    texts = [m.strip() for m in messages if m and m.strip()]
    if not texts:
        # NOTE(review): 0 is kept for backward compatibility, although on this
        # scale it reads as "fully negative"; 50 would be the neutral value.
        return 0
    texts = texts[-NUM_MESSAGES:]

    # The pipeline tokenizes every message itself, so no manual padding is
    # needed. truncation=True keeps over-long messages within the model's
    # maximum input length.
    predictions = pipe(texts, truncation=True)

    signed_scores = [
        p["score"] if p["label"] == "POSITIVE" else -p["score"]
        for p in predictions
    ]
    # Average over the messages actually scored, not the NUM_MESSAGES
    # capacity, so fewer messages still yield a true mean.
    mean = sum(signed_scores) / len(signed_scores)
    # Map the mean from [-1, 1] onto [0, 100].
    return round((mean + 1) * 50, 2)
# Web UI: one text box per scored message, a single numeric sentiment output.
demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=["text"] * NUM_MESSAGES,
    outputs=["number"],
    title="Sentiment Analysis",
    description=f"Analyze the sentiment of the last {NUM_MESSAGES} messages",
)

# Launch only when run as a script, so importing this module does not start
# a blocking server.
if __name__ == "__main__":
    demo.launch()