# Pocket-Callisto / app.py
# Uploaded by prithivMLmods; commit 6693a45 ("Update app.py", verified); 3.89 kB.
import argparse
import spaces
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
def get_args():
    """Parse the command-line options understood by this script.

    Returns:
        The ``(namespace, extras)`` pair produced by
        ``ArgumentParser.parse_known_args``. ``namespace`` carries ``model``,
        ``max_length`` and ``do_sample``; ``extras`` is the list of argv
        entries that matched none of the declared options.
    """
    option_specs = (
        ("--model", {"type": str, "default": "prithivMLmods/Pocket-Llama-3.2-3B-Instruct"}),
        ("--max_length", {"type": int, "default": 512}),
        ("--do_sample", {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    # Unrecognized arguments (e.g. the ones Jupyter injects) are handed back
    # separately instead of aborting the parse.
    known, extras = cli.parse_known_args()
    return known, extras
def load_model(model_name):
    """Load a causal language model and its tokenizer from Hugging Face.

    Args:
        model_name: Identifier passed unchanged to both ``from_pretrained``
            calls.

    Returns:
        A ``(model, tokenizer)`` pair; note that the model comes first.
    """
    tok = AutoTokenizer.from_pretrained(model_name)

    # The model is requested in bfloat16 with device_map="auto".
    model_options = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
    }
    lm = AutoModelForCausalLM.from_pretrained(model_name, **model_options)

    return lm, tok
def generate_reply(model, tokenizer, prompt, max_length, do_sample):
    """Run one non-streaming generation pass and return the decoded text.

    Args:
        model: Object exposing ``device`` and ``generate``.
        tokenizer: Callable used to encode ``prompt``; its ``decode`` method
            turns the generated sequence back into text.
        prompt: Text handed to the tokenizer.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        do_sample: Forwarded to ``model.generate`` as ``do_sample``.

    Returns:
        The first sequence returned by ``model.generate``, decoded with
        special tokens skipped.
    """
    # NOTE(review): in Hugging Face transformers, `max_length` presumably
    # bounds prompt + generated tokens together, so a long prompt may leave no
    # room for a reply -- confirm against the generate() documentation.
    encoded = tokenizer(prompt, return_tensors="pt")
    encoded = encoded.to(model.device)

    generation_options = {"max_length": max_length, "do_sample": do_sample}
    sequences = model.generate(**encoded, **generation_options)

    # Only the complete final string is returned; nothing is streamed.
    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def _build_prompt(chat_history, user_message):
    """Render earlier (user, bot) turns plus the new message as one transcript."""
    segments = [
        f"User: {old_user_msg}\nBot: {old_bot_msg}\n"
        for old_user_msg, old_bot_msg in chat_history
    ]
    # The transcript ends on an open "Bot:" cue for the model to continue.
    segments.append(f"User: {user_message}\nBot:")
    return "".join(segments)


def _extract_reply(generated_text, prompt):
    """Cut the raw model output down to the bot's answer for the current turn."""
    # The decoded output normally echoes the prompt; drop that echo.
    if generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):]
    # A model continuing a "User:/Bot:" transcript can keep going and invent
    # the next user turn. Keep only the text before the first such marker so
    # fabricated turns are neither displayed nor fed back into later prompts.
    reply, _, _ = generated_text.partition("\nUser:")
    # The continuation of "Bot:" usually starts with a space.
    return reply.strip()


def main():
    """Load the model and serve a Gradio chat UI around it.

    Reads the command-line options via ``get_args``, loads the model and
    tokenizer via ``load_model``, builds a Blocks interface with a chatbot,
    a textbox and a "Clear" button, and launches it.
    """
    args, _ = get_args()
    model, tokenizer = load_model(args.model)

    @spaces.GPU
    def respond(user_message, chat_history):
        """
        Gradio expects a function that takes the last user message and the
        conversation history, then returns the updated history.
        chat_history is a list of (user_message, bot_reply) pairs.
        """
        # Work on a copy: never mutate the list Gradio passed in, and accept
        # an empty/None history.
        history = list(chat_history or [])

        # Nothing to answer: skip the generation for blank submissions.
        if not user_message or not user_message.strip():
            return history, history

        # Build a single text prompt from the conversation so far
        prompt = _build_prompt(history, user_message)

        # Generate the response
        raw_output = generate_reply(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_length=args.max_length,
            do_sample=args.do_sample,
        )
        bot_message = _extract_reply(raw_output, prompt)

        # Append the new user-message and bot-response to the history
        history.append((user_message, bot_message))
        return history, history

    # Define the Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("<h2 style='text-align: center;'>Chat with Your Model</h2>")

        # A Chatbot component that will display the conversation
        chatbot = gr.Chatbot(label="Chat")

        # A text box for user input
        user_input = gr.Textbox(
            show_label=False,
            placeholder="Type your message here and press Enter"
        )

        # A button to clear the conversation
        clear_button = gr.Button("Clear")

        # When the user hits Enter in the textbox, call 'respond'
        # - Inputs: [user_input, chatbot] (the last user message and history)
        # - Outputs: [chatbot, chatbot] (updates the chatbot display and history)
        user_input.submit(respond, [user_input, chatbot], [chatbot, chatbot])

        # Define a helper function for clearing
        def clear_conversation():
            return [], []

        # When "Clear" is clicked, reset the conversation
        clear_button.click(fn=clear_conversation, outputs=[chatbot, chatbot])

    # Launch the Gradio app
    demo.launch()
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()