|
import gradio as gr |
|
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM |
|
import torch |
|
|
|
|
|
model_name = "ombhojane/mental-health-assistant"

# Load the tokenizer and weights once, at import time, so every chat request
# reuses the same objects instead of reloading the model per message.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Run on the first CUDA GPU when one is present; -1 selects the CPU, which is
# what the pipeline used unconditionally before (torch was imported for this
# but never used).
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)
|
|
|
def _build_prompt(message, history):
    """Render the conversation as a plain-text "User:/Assistant:" transcript.

    ``history`` may be a sequence of ``(user, bot)`` pairs (the classic
    ``gr.Chatbot`` value) or of ``{"role": ..., "content": ...}`` dicts
    (``type="messages"``); ``None`` is treated as an empty conversation.
    ``None`` turns (e.g. a reply that is still pending) are skipped instead of
    being rendered as the text "None". The transcript ends with an open
    "Assistant:" turn for the model to complete.
    """
    lines = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Messages format: keep only plain-text user/assistant turns.
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                lines.append(f"{role.capitalize()}: {content}")
            continue
        user_msg, bot_msg = turn
        if user_msg is not None:
            lines.append(f"User: {user_msg}")
        if bot_msg is not None:
            lines.append(f"Assistant: {bot_msg}")
    lines.append(f"User: {message}")
    lines.append("Assistant:")
    return "\n".join(lines)


def _extract_reply(continuation):
    """Return only the first assistant turn from the model's continuation.

    A sampled continuation can keep going and invent further "User:" /
    "Assistant:" turns; everything from the first such marker on is dropped.
    """
    reply = continuation
    for marker in ("\nUser:", "\nAssistant:"):
        reply = reply.split(marker, 1)[0]
    return reply.strip()


def generate_response(message, history, *, max_new_tokens=256):
    """Generate the assistant's reply to ``message`` given prior chat turns.

    Args:
        message: The user's latest message.
        history: Prior turns, as ``(user, bot)`` pairs or role/content dicts;
            ``None`` is treated as an empty conversation.
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply. It counts only generated tokens, not the prompt, so a long
            conversation still leaves the model room to answer.

    Returns:
        The reply text with surrounding whitespace stripped (possibly empty if
        the model immediately starts a new turn).
    """
    prompt = _build_prompt(message, history)
    outputs = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        # Return only the newly generated text so the prompt does not have to
        # be split back off the output.
        return_full_text=False,
    )
    return _extract_reply(outputs[0]["generated_text"])
|
|
|
|
|
def _respond(message, history):
    """Run one chat turn for the UI and return ``(textbox_value, history)``.

    ``generate_response`` returns only the reply text, while the ``gr.Chatbot``
    output expects the full conversation, so the new ``[message, reply]`` pair
    is appended here. Blank messages are not sent to the model. The returned
    empty string clears the input textbox once the reply is shown.
    """
    history = list(history or [])
    if not message or not message.strip():
        return "", history
    reply = generate_response(message, history)
    history.append([message, reply])
    return "", history


with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown(
        """
        # 🧠 Mental Health Assistant

        Welcome! I'm here to provide support and guidance for your mental health concerns.

        While I can offer helpful insights, please remember I'm not a replacement for professional medical advice.
        """
    )

    chatbot = gr.Chatbot(
        height=400,
        show_label=False,
        container=True,
        bubble_full_width=False,
    )

    with gr.Row():
        msg = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Type your message here...",
            container=False,
        )
        submit = gr.Button("Send", scale=1, variant="primary")

    gr.Examples(
        examples=[
            "I've been feeling really anxious lately",
            "How can I improve my sleep habits?",
            "I'm having trouble focusing at work",
        ],
        inputs=msg,
    )

    gr.Markdown(
        """
        ### Tips for best results:

        - Be specific about how you're feeling
        - Ask direct questions
        - Share relevant context
        - Take your time to explain your situation
        """
    )

    # Clicking "Send" and pressing Enter in the textbox behave identically:
    # update the conversation and clear the textbox in a single step.
    for trigger in (submit.click, msg.submit):
        trigger(
            _respond,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=True,
        )

demo.launch()