# Spaces: Running on Zero
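The script below is a minimal Gradio chat app for a ZeroGPU Space: it loads a causal language model with `transformers`, builds a plain-text prompt from the chat history, and generates replies on demand. On ZeroGPU, the `spaces` package supplies the `@spaces.GPU` decorator, which attaches a GPU only for the duration of the decorated call.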
```python
import argparse

import spaces  # ZeroGPU helper; provides the @spaces.GPU decorator
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="prithivMLmods/Pocket-Llama-3.2-3B-Instruct")
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--do_sample", action="store_true")
    # parse_known_args ignores unrecognized arguments, e.g. those injected by Jupyter
    return parser.parse_known_args()

def load_model(model_name):
    """Load model and tokenizer from Hugging Face."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    return model, tokenizer

@spaces.GPU  # on a ZeroGPU Space, a GPU is attached only while this call runs
def generate_reply(model, tokenizer, prompt, max_length, do_sample):
    """Generate text from the model given a prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Return the final string in one shot; no streaming here
    output_tokens = model.generate(
        **inputs,
        max_length=max_length,  # counts prompt tokens plus new tokens
        do_sample=do_sample,
    )
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)

def main():
    args, _ = get_args()
    model, tokenizer = load_model(args.model)
    def respond(user_message, chat_history):
        """
        Gradio calls this with the latest user message and the conversation
        history; it returns a cleared textbox value and the updated history.
        chat_history is a list of (user_message, bot_reply) pairs.
        """
        # Build a single text prompt from the conversation so far
        prompt = ""
        for old_user_msg, old_bot_msg in chat_history:
            prompt += f"User: {old_user_msg}\nBot: {old_bot_msg}\n"
        # Add the new user query
        prompt += f"User: {user_message}\nBot:"

        # Generate the response
        bot_message = generate_reply(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_length=args.max_length,
            do_sample=args.do_sample,
        )

        # The decoded output typically repeats the whole prompt, so strip the
        # prompt prefix and keep only the newly generated text
        if bot_message.startswith(prompt):
            bot_message = bot_message[len(prompt):]

        # Record the exchange, then return "" to clear the input textbox
        chat_history.append((user_message, bot_message))
        return "", chat_history
    # Define the Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("<h2 style='text-align: center;'>Chat with Your Model</h2>")

        # A Chatbot component that displays the conversation
        chatbot = gr.Chatbot(label="Chat")

        # A text box for user input
        user_input = gr.Textbox(
            show_label=False,
            placeholder="Type your message here and press Enter",
        )

        # A button to clear the conversation
        clear_button = gr.Button("Clear")

        # When the user hits Enter in the textbox, call `respond`
        # - inputs: [user_input, chatbot] (the new message and the history)
        # - outputs: [user_input, chatbot] (clears the textbox, updates the chat)
        user_input.submit(respond, [user_input, chatbot], [user_input, chatbot])
        # Define a helper function for clearing
        def clear_conversation():
            return []

        # When "Clear" is clicked, empty the chat display
        clear_button.click(fn=clear_conversation, outputs=[chatbot])
    # Launch the Gradio app
    demo.launch()


if __name__ == "__main__":
    main()
```
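To try it locally (hardware permitting), save the file as `app.py` and run `python app.py` (add `--do_sample` for sampled decoding); on a Space using the Gradio SDK, `app.py` is launched automatically.

The hand-rolled `User:`/`Bot:` prompt works, but instruct-tuned models generally respond better when the conversation is formatted with their own chat template. As a sketch, the prompt-building step could be swapped for something like the hypothetical helper below, which assumes the tokenizer ships a chat template (`build_prompt` and its message layout are illustrative, not part of the original script):

```python
def build_prompt(tokenizer, chat_history, user_message):
    """Format the conversation with the model's chat template, if it has one."""
    messages = []
    for old_user_msg, old_bot_msg in chat_history:
        messages.append({"role": "user", "content": old_user_msg})
        messages.append({"role": "assistant", "content": old_bot_msg})
    messages.append({"role": "user", "content": user_message})
    # Returns a single string with the template applied and the assistant
    # turn opened, ready to pass to generate_reply()
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
```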