|
from fastapi import FastAPI |
|
import gradio as gr |
|
|
|
from transformers import pipeline |
|
from gradio.components import Textbox |
|
|
|
# FastAPI application object; the Gradio UI is attached to it via mount_gradio_app.
app = FastAPI()

# Hugging Face sentiment classifier (DistilBERT fine-tuned on SST-2, English).
# Created at module level, so the model is loaded when this module is imported.
distilbert_pipeline = pipeline(
    task="sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)

# Translate the classifier's sentiment labels into the tags shown to users:
# NEGATIVE is reported as SENSITIVE and POSITIVE as OTHER.
label_map = dict(POSITIVE="OTHER", NEGATIVE="SENSITIVE")

# Two-line text box used as the Gradio input component.
input1 = Textbox(placeholder="Type your text here...", lines=2)
|
|
|
def predict_sentiment(text):
    """
    Tag the input text as SENSITIVE or OTHER using the DistilBERT sentiment model.

    The model's NEGATIVE label is reported as SENSITIVE and its POSITIVE label
    as OTHER, via ``label_map``.

    :param text: str, input text to analyze.
    :return: str of the form "TAG: <label>, Confidence: <score>", where score is
        the model's score for its predicted label, formatted to two decimals.
    """
    # truncation=True clips inputs longer than the model's maximum sequence
    # length instead of letting the pipeline fail on long paragraphs.
    result = distilbert_pipeline(text, truncation=True)[0]
    label = label_map[result["label"]]
    score = result["score"]
    return f"TAG: {label}, Confidence: {score:.2f}"


# Build the Gradio UI once, at module level. It must exist here so that
# mount_gradio_app below can reference it. Building it inside a request handler
# would hide it from the mount call (NameError) and would re-create the UI on
# every request.
text_input = gr.Interface(
    fn=predict_sentiment,
    inputs=input1,
    outputs="text",
    title="Talk2Loop Sensitive statement tags",
    description="This model predicts the sensitivity of the input text. Enter a sentence to see if it's sensitive or not.",
)

# Serve the UI from the FastAPI app itself at "/". No separate `@app.get("/")`
# handler is registered, for three reasons:
#   - a FastAPI route at "/" would shadow the mounted UI;
#   - calling `launch()` inside a handler tries to start a second, blocking
#     server on every request;
#   - Gradio's `launch` takes `server_name` / `server_port`, not `host` / `port`.
# Run the combined app with an ASGI server instead, e.g.
#   uvicorn <module>:app --host 0.0.0.0 --port 8000
# NOTE(review): public `share=True` links come from `launch()`. Mounting does
# not create one, so confirm a public tunnel is not required.
app = gr.mount_gradio_app(app, text_input, path="/")
|
|
|
|