import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, StopStringCriteria, StoppingCriteriaList
import torch
# Load the tokenizer and model
repo_name = "nvidia/Hymba-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
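# trust_remote_code=True is required because Hymba ships its custom modeling code on the Hugging Face Hub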
# Move the model to GPU with float16 precision for efficiency
model = model.to("cuda").to(torch.float16)
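# NOTE: this assumes a CUDA-capable GPU is available; on a CPU-only machine,
# keep the model on "cpu" and use float32 instead of float16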
# Initialize the conversation history
messages = [
    {"role": "system", "content": "You are a helpful assistant."}
]
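# NOTE: this list is module-level state, so the same history is shared by every
# visitor to the app; per-user history would need something like gr.State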
# Define stopping criteria
stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings=["</s>"])])
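# Generation halts as soon as the model emits the end-of-sequence string "</s>"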
# Chat function for the Gradio interface
def chat_function(user_input):
    # Add the user message to the conversation history
    messages.append({"role": "user", "content": user_input})
    # Apply the model's chat template and tokenize the conversation
    tokenized_chat = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # Generate a response (greedy decoding, capped at 256 new tokens)
    outputs = model.generate(
        tokenized_chat,
        max_new_tokens=256,
        do_sample=False,
        use_cache=True,
        stopping_criteria=stopping_criteria
    )
    # Decode only the newly generated tokens, not the echoed prompt
    response = tokenizer.decode(outputs[0][tokenized_chat.shape[1]:], skip_special_tokens=True)
    # Add the assistant's response to the conversation history
    messages.append({"role": "assistant", "content": response})
    return response
# Set up the Gradio interface
iface = gr.Interface(
    fn=chat_function,
    inputs=gr.Textbox(label="Your message", placeholder="Enter your message here..."),
    outputs=gr.Textbox(label="Response"),
    title="Hymba Chatbot",
    description="Chat with the Hymba-1.5B-Instruct model!"
)
# Launch the Gradio interface
iface.launch()