# Spaces: Runtime error
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
def create_roberta_chatbot_tab(model_name="roberta-base"):
    """Build a Gradio chat interface that answers with canned supportive replies.

    Each user message is run through a Hugging Face ``text-classification``
    pipeline. The predicted label (``"LABEL_0"`` .. ``"LABEL_4"``) selects one
    of a fixed set of responses. Any other label, or an empty message, gets a
    generic follow-up prompt.

    NOTE(review): ``"roberta-base"`` is a base checkpoint, not one fine-tuned
    for classification. ``AutoModelForSequenceClassification`` will therefore
    attach a freshly initialised classification head with the config's default
    label count (presumably 2 -> only ``LABEL_0``/``LABEL_1``; confirm for your
    transformers version). Predictions from the default model are arbitrary,
    and ``LABEL_2``-``LABEL_4`` cannot occur. Pass a checkpoint fine-tuned for
    the intended label set via ``model_name``.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``from_pretrained``. Defaults to ``"roberta-base"`` to keep the
            original behaviour.

    Returns:
        A ``gr.ChatInterface`` whose callback is backed by the classifier.
    """
    # Fixed replies keyed by the pipeline's label names. Built once here
    # instead of on every message.
    responses = {
        "LABEL_0": "I understand you might be going through a difficult time. Remember, it's okay to seek help when you need it.",
        "LABEL_1": "Your feelings are valid. Have you considered talking to a mental health professional about this?",
        "LABEL_2": "Taking care of your mental health is crucial. Small steps like regular exercise and good sleep can make a big difference.",
        "LABEL_3": "It sounds like you're dealing with a lot. Remember, you're not alone in this journey.",
        "LABEL_4": "I hear you. Coping with mental health challenges can be tough. Have you tried any relaxation techniques like deep breathing or meditation?",
    }
    default_response = "I'm here to listen and support you. Could you tell me more about how you're feeling?"

    # Load the tokenizer and model once, when the tab is built, not per message.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)

    def generate_response(input_text):
        """Return the canned reply for the label predicted for *input_text*."""
        # Nothing to classify: ask the user to say more instead of running the
        # model on an empty string.
        if not (input_text and input_text.strip()):
            return default_response
        # truncation=True clips messages longer than the model's maximum input
        # length instead of letting them fail inside the model.
        label = classifier(input_text, truncation=True)[0]["label"]
        return responses.get(label, default_response)

    def chatbot(message, history):
        """Gradio ``ChatInterface`` callback; the chat *history* is not used."""
        return generate_response(message)

    return gr.ChatInterface(fn=chatbot)