import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, StopStringCriteria, StoppingCriteriaList
import torch

# Load the tokenizer and model
repo_name = "nvidia/Hymba-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)

# Move the model to GPU with float16 precision for efficiency
model = model.to("cuda").to(torch.float16)

# Initialize the conversation history with a system prompt
messages = [
    {"role": "system", "content": "You are a helpful assistant."}
]

# Stop generation when the model emits the end-of-sequence string
stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings=["</s>"])])
# Chat function for the Gradio interface
def chat_function(user_input):
    # Add the user message to the conversation history
    messages.append({"role": "user", "content": user_input})

    # Apply the chat template and tokenize the conversation
    tokenized_chat = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to("cuda")

    # Generate a response (greedy decoding, so no sampling temperature is used)
    outputs = model.generate(
        tokenized_chat,
        max_new_tokens=256,
        do_sample=False,
        use_cache=True,
        stopping_criteria=stopping_criteria
    )

    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(outputs[0][tokenized_chat.shape[1]:], skip_special_tokens=True)

    # Add the assistant's response to the conversation history
    messages.append({"role": "assistant", "content": response})
    return response
# Set up the Gradio interface
iface = gr.Interface(
    fn=chat_function,
    inputs=gr.Textbox(label="Your message", placeholder="Enter your message here..."),
    outputs=gr.Textbox(label="Response"),
    title="Hymba Chatbot",
    description="Chat with the Hymba-1.5B-Instruct model!"
)

# Launch the Gradio interface
iface.launch()