import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from huggingface_hub import hf_hub_download

# Model and tokenizer details
model_repo = "elapt1c/ElapticAI-1a"
model_filename = "model.pth"  # Adjust if the weights file in the HF repo has a different name.
tokenizer_name = "microsoft/DialoGPT-medium"
# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

# Load model configuration
config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium")

# Initialize model from config (important to use the same architecture as the checkpoint)
model = AutoModelForCausalLM.from_config(config)
# Download and load model weights
try:
    pth_filepath = hf_hub_download(repo_id=model_repo, filename=model_filename)
    checkpoint = torch.load(pth_filepath, map_location=device)

    # Handle different checkpoint saving formats if needed.
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    elif 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        # Assume the checkpoint is just the raw state_dict.
        model.load_state_dict(checkpoint)
    print(f"Successfully loaded model weights from {model_repo}/{model_filename}")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please ensure the model repository and filename are correct.")
    raise  # Re-raise so the error is visible in the Space logs.

model.to(device)
model.eval()  # Set model to evaluation mode
def chat_with_model(user_input, history):
    """Chatbot function used by gr.ChatInterface; returns the bot's reply."""
    history_transformer_format = history_to_transformer_format(history)
    # DialoGPT expects each turn to be terminated by the EOS token.
    input_text = tokenizer.eos_token.join(history_transformer_format + [user_input]) + tokenizer.eos_token
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=1000,  # Adjust as needed
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,  # Required for temperature/top_p to take effect
            temperature=0.7,
            top_p=0.9
        )

    # Decode only the newly generated tokens, i.e. everything after the prompt.
    bot_response = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
    return bot_response

def history_to_transformer_format(history):
    """Convert gradio chat history ([user, bot] pairs) to a flat list of strings for model input."""
    history_formatted = []
    for user_msg, bot_msg in history:
        history_formatted.append(user_msg)
        history_formatted.append(bot_msg)
    return history_formatted

# gr.ChatInterface manages the chat box and history itself, so no explicit
# input/output components are needed.
iface = gr.ChatInterface(
    fn=chat_with_model,
    title="ElapticAI-1a Chatbot",
    description="Simple chatbot interface for the ElapticAI-1a model. Talk to the model and see its responses!",
    examples=["Hello", "How are you?", "Tell me a joke"]
)

if __name__ == "__main__":
    iface.launch()