Depreesion / tabs /roberta_chatbot.py
vitorcalvi's picture
pre-launch
fc286f6
raw
history blame
1.91 kB
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
def create_roberta_chatbot_tab(model_name="roberta-base", title=None):
    """Build a Gradio chat interface backed by a RoBERTa sequence classifier.

    Each user message is classified by a Hugging Face ``text-classification``
    pipeline, and the predicted label is mapped to a canned supportive reply.

    NOTE(review): the default ``"roberta-base"`` checkpoint ships without a
    fine-tuned classification head, so ``AutoModelForSequenceClassification``
    initialises one randomly with the default two labels. Predictions are then
    not meaningful, and only ``LABEL_0``/``LABEL_1`` can occur, so the
    ``LABEL_2``-``LABEL_4`` replies are unreachable. Pass a checkpoint
    fine-tuned for this task via ``model_name`` to make the mapping useful.

    Args:
        model_name: Hugging Face model id or local path to load. Defaults to
            ``"roberta-base"``, the previously hard-coded value.
        title: Optional title passed to ``gr.ChatInterface``; ``None`` keeps
            Gradio's default (no title).

    Returns:
        The constructed ``gr.ChatInterface``.
    """
    # Load the model and tokenizer once, when the tab is built, not per message.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    # Create a text classification pipeline
    classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)

    # Canned replies keyed by the pipeline's label names. Built once here
    # instead of on every call to generate_response.
    responses = {
        "LABEL_0": "I understand you might be going through a difficult time. Remember, it's okay to seek help when you need it.",
        "LABEL_1": "Your feelings are valid. Have you considered talking to a mental health professional about this?",
        "LABEL_2": "Taking care of your mental health is crucial. Small steps like regular exercise and good sleep can make a big difference.",
        "LABEL_3": "It sounds like you're dealing with a lot. Remember, you're not alone in this journey.",
        "LABEL_4": "I hear you. Coping with mental health challenges can be tough. Have you tried any relaxation techniques like deep breathing or meditation?",
    }
    default_response = "I'm here to listen and support you. Could you tell me more about how you're feeling?"

    def generate_response(input_text):
        """Return the canned reply for the label predicted for *input_text*.

        Falls back to ``default_response`` for empty input or for a label
        that has no entry in ``responses``.
        """
        # Nothing to classify: ask the user to say more instead of running
        # the model on empty text.
        if not (input_text and input_text.strip()):
            return default_response
        # truncation=True stops messages longer than the model's maximum
        # sequence length from raising inside the model.
        label = classifier(input_text, truncation=True)[0]["label"]
        return responses.get(label, default_response)

    def chatbot(message, history):
        """Gradio ``ChatInterface`` callback; *history* is accepted but unused."""
        return generate_response(message)

    # Create Gradio interface
    return gr.ChatInterface(fn=chatbot, title=title)