import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from huggingface_hub import hf_hub_download
import os
# Model and tokenizer details
model_repo = "elapt1c/ElapticAI-1a"
model_filename = "model.pth" # Assuming the model is saved as pytorch_model.bin, adjust if needed. Check the HF repo.
tokenizer_name = "microsoft/DialoGPT-medium"
# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
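# Note: the DialoGPT tokenizer has no dedicated pad token, which is why
# generate() below reuses eos_token_id as the pad_token_id.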
# Load model configuration
config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium")
# Initialize model from config (important to use the same architecture)
model = AutoModelForCausalLM.from_config(config)
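# from_config builds a randomly initialized model with the DialoGPT-medium
# architecture; its weights are replaced by the downloaded checkpoint below.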
# Download and load model weights
try:
    pth_filepath = hf_hub_download(repo_id=model_repo, filename=model_filename)
    checkpoint = torch.load(pth_filepath, map_location=device)

    # Handle different checkpoint saving formats.
    # If the checkpoint is just the state_dict, load it directly.
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    elif 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        # Assume the checkpoint is the raw state_dict.
        model.load_state_dict(checkpoint)

    print(f"Successfully loaded model weights from {model_repo}/{model_filename}")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please ensure the model repository and filename are correct.")
    raise  # Re-raise so the error is visible in the Space logs.
model.to(device)
model.eval() # Set model to evaluation mode
def chat_with_model(user_input, history):
    """Chatbot function used by gr.ChatInterface.

    `history` is the list of (user, bot) message pairs that Gradio passes in.
    Returns only the bot's reply; ChatInterface manages the history itself.
    """
    history_transformer_format = history_to_transformer_format(history)
    # DialoGPT-style prompt: turns separated (and terminated) by the EOS token.
    input_text = tokenizer.eos_token.join(history_transformer_format + [user_input]) + tokenizer.eos_token
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=1000,  # Adjust as needed
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,  # Required for temperature/top_p to take effect
            temperature=0.7,
            top_p=0.9
        )

    # Decode only the newly generated tokens, i.e. the bot's reply after the prompt.
    response_ids = output[0][input_ids.shape[-1]:]
    bot_response = tokenizer.decode(response_ids, skip_special_tokens=True).strip()
    return bot_response
def history_to_transformer_format(history):
    """Convert Gradio (user, bot) history pairs to a flat list of strings."""
    history_formatted = []
    for user_msg, bot_msg in history:
        history_formatted.append(user_msg)
        history_formatted.append(bot_msg)
    return history_formatted
iface = gr.ChatInterface(
    fn=chat_with_model,
    chatbot=gr.Chatbot(),
    textbox=gr.Textbox(placeholder="Type your message here..."),
    title="ElapticAI-1a Chatbot",
    description="Simple chatbot interface for the ElapticAI-1a model. Talk to the model and see its responses!",
    examples=["Hello", "How are you?", "Tell me a joke"]
)
if __name__ == "__main__":
    iface.launch()