import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Model Initialization ---
# Paths for the tokenizer and your fine-tuned model checkpoint
tokenizer_path = "facebook/opt-1.3b"
model_path = "transfer_learning_therapist.pth"

# Load tokenizer and set a pad token if one is missing
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load the base model, then overwrite matching weights with your checkpoint.
# Only keys present in both state dicts are copied; any remaining layers keep
# the base pretrained weights.
model = AutoModelForCausalLM.from_pretrained(tokenizer_path)
checkpoint = torch.load(model_path, map_location=device)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in checkpoint["model_state_dict"].items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model.to(device)
model.eval()
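# Note: depending on the checkpoint layout, a roughly equivalent shortcut is
#     model.load_state_dict(checkpoint["model_state_dict"], strict=False)
# which also tolerates missing/unexpected keys; the explicit merge above is
# kept for clarity.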
# --- Inference Function ---
def generate_response(prompt, max_new_tokens=150, temperature=0.7, top_p=0.9, repetition_penalty=1.2):
    """Generate a response from the model for the given prompt."""
    model.eval()
    model.config.use_cache = True

    prompt = prompt.strip()
    if not prompt:
        return "Please provide a valid input."

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    try:
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=repetition_penalty,
                num_beams=1,              # sampling without beam search
                no_repeat_ngram_size=3,   # avoid repeated phrases
            )
    except Exception as e:
        return f"Error generating response: {e}"
    finally:
        model.config.use_cache = False

    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # If the prompt uses role markers (e.g., "Therapist:"), keep only the last therapist turn
    if "Therapist:" in full_response:
        therapist_response = full_response.split("Therapist:")[-1].strip()
    else:
        therapist_response = full_response.strip()

    return therapist_response
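# Example usage (a minimal sketch; the role-labelled prompt mirrors the format
# built by respond() below and is illustrative only):
#     print(generate_response("Human: I've been feeling anxious lately.\nTherapist:"))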
# --- Gradio Interface Function ---
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """
    Build the conversation context from the system message and the dialogue
    history, then generate a new response from the model.
    """
    # Assemble a conversation prompt with the role labels the model expects.
    conversation = f"System: {system_message}\n"
    for user_msg, assistant_msg in history:
        conversation += f"Human: {user_msg}\nTherapist: {assistant_msg}\n"
    conversation += f"Human: {message}\nTherapist:"

    response = generate_response(
        conversation,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    # gr.ChatInterface manages the chat history itself, so only the reply is returned.
    return response
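# For reference, the prompt assembled by respond() looks roughly like this
# (turns are illustrative):
#     System: You are a friendly AI Therapist.
#     Human: I can't sleep at night.
#     Therapist: That sounds exhausting. Tell me more about your evenings.
#     Human: What can I do about it?
#     Therapist: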
# --- Gradio ChatInterface Setup ---
demo = gr.ChatInterface(
    fn=respond,
    title="MindfulAI Chat",
    description="Chat with MindfulAI – an AI Therapist powered by your custom model.",
    additional_inputs=[
        gr.Textbox(value="You are a friendly AI Therapist.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
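# Optional: enabling request queueing before launch lets concurrent users share
# the single model instance without blocking each other (a sketch; behaviour
# depends on the installed Gradio version):
#     demo.queue()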
if __name__ == "__main__":
    demo.launch()