import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from huggingface_hub import hf_hub_download
# ----- Model Definition -----
class CustomDialoGPT(nn.Module):
    def __init__(self, vocab_size, n_embd=768, n_head=8, n_layer=8):
        # n_embd, n_head and n_layer are forced to match the checkpoint's
        # architecture; load_state_dict will fail if they differ.
        super().__init__()
        config = AutoConfig.from_pretrained(
            "microsoft/DialoGPT-medium",
            vocab_size=vocab_size,
            n_embd=n_embd,
            n_head=n_head,
            n_layer=n_layer,
            bos_token_id=50256,
            eos_token_id=50256,
            pad_token_id=50256,
        )
        # Full causal-LM backbone; note that generate() further below uses this
        # model's built-in LM head, while forward() uses self.lm_head.
        self.transformer = AutoModelForCausalLM.from_config(config)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, input_ids):
        transformer_outputs = self.transformer(input_ids=input_ids, output_hidden_states=True)
        hidden_states = transformer_outputs.hidden_states[-1]  # last layer's hidden states
        logits = self.lm_head(hidden_states)
        return logits
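
# Optional sanity check: the forward pass should map (batch, seq) token ids to
# (batch, seq, vocab_size) logits. A sketch using GPT-2's 50257-token vocab as
# an illustrative size; uncomment to run:
# _dummy_ids = torch.randint(0, 50257, (1, 8))
# _logits = CustomDialoGPT(vocab_size=50257)(_dummy_ids)
# assert _logits.shape == (1, 8, 50257)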
# Model and tokenizer details
model_repo = "elapt1c/ElapticAI-1a"
model_filename = "model.pth" # <--- CHECK FILENAME ON HF HUB, UPDATE IF NEEDED!
tokenizer_name = "microsoft/DialoGPT-medium"
# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
vocab_size = len(tokenizer) # <---- Define vocab_size AFTER loading tokenizer
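# GPT-2-family tokenizers ship without a pad token; aliasing it to the EOS
# token (matching pad_token_id=50256 in the config above) is a common
# safeguard that keeps generate() from warning about padding.
tokenizer.pad_token = tokenizer.eos_token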
# Initialize the model with fixed hyperparameters matching the checkpoint
n_embd = 768
n_head = 8
n_layer = 8
model = CustomDialoGPT(vocab_size, n_embd, n_head, n_layer).to(device).eval()
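
# Optional: a parameter count is a quick cross-check that the architecture
# matches the checkpoint (uncomment to print; the exact figure depends on
# vocab_size):
# print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")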
# Download and load model weights
try:
    pth_filepath = hf_hub_download(repo_id=model_repo, filename=model_filename)
    checkpoint = torch.load(pth_filepath, map_location=device)
    # Handle the common checkpoint saving formats.
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    elif 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)
    print(f"Successfully loaded model weights from {model_repo}/{model_filename}")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please ensure the model repository and filename are correct and that "
          "the model architecture in app.py matches the checkpoint.")
    raise  # Re-raise so the failure is visible in the Space logs.
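
# If load_state_dict fails with missing/unexpected keys, comparing key names
# usually pinpoints the mismatch (a debugging sketch; uncomment inside the
# try block above, where `checkpoint` is in scope):
# print(sorted(checkpoint.keys())[:10])
# print(sorted(model.state_dict().keys())[:10])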
def chat_with_model(user_input):
    """Generate a single text response from the loaded model."""
    input_ids = tokenizer.encode(user_input, return_tensors='pt').to(device)
    with torch.no_grad():
        output = model.transformer.generate(
            inputs=input_ids,
            max_length=100,
            pad_token_id=tokenizer.eos_token_id,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
    # Decode the full sequence; note it includes the prompt (see sketch below).
    bot_response = tokenizer.decode(output[0], skip_special_tokens=True)
    # Debug output, visible in the Space logs.
    print("--- chat_with_model ---")
    print("user_input:", user_input)
    print("bot_response:", bot_response)
    return bot_response
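
# generate() returns prompt + continuation, so bot_response above still echoes
# the user's input. To return only the reply, one could slice off the prompt
# tokens before decoding (a sketch to use inside chat_with_model):
# reply_ids = output[0][input_ids.shape[-1]:]
# bot_response = tokenizer.decode(reply_ids, skip_special_tokens=True)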
iface = gr.Interface(
    fn=chat_with_model,
    inputs=gr.Textbox(placeholder="Type your message here..."),
    outputs=gr.Text(),
    title="ElapticAI-1a Chatbot",
    description="Simple chatbot interface for the ElapticAI-1a model",
)
if __name__ == "__main__":
    iface.launch()