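"""Gradio chat demo: load a causal LM with transformers and serve it on a ZeroGPU Space."""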
import argparse
import spaces
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="prithivMLmods/Pocket-Llama-3.2-3B-Instruct")
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--do_sample", action="store_true")
    # parse_known_args() ignores unrecognized arguments, e.g., those injected by Jupyter
    return parser.parse_known_args()


def load_model(model_name):
    """Load model and tokenizer from Hugging Face."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # half-precision weights to reduce memory use
        device_map="auto"            # let accelerate place the model on available devices
    )
    return model, tokenizer


def generate_reply(model, tokenizer, prompt, max_length, do_sample):
    """Generate text from the model given a prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # We're returning just the final string; no streaming here.
    # Note: max_length caps prompt + generated tokens combined
    # (use max_new_tokens instead to cap only the reply length).
    output_tokens = model.generate(
        **inputs,
        max_length=max_length,
        do_sample=do_sample
    )
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)


def main():
    args, _ = get_args()
    model, tokenizer = load_model(args.model)

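    # On ZeroGPU Spaces, the `spaces.GPU` decorator requests a GPU for the
    # duration of each call to the decorated function and releases it afterwards.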
    @spaces.GPU
    def respond(user_message, chat_history):
        """
        Gradio expects a function that takes the last user message and the
        conversation history, then returns the updated history.
        chat_history is a list of (user_message, bot_reply) pairs.
        """
        # Build a single text prompt from the conversation so far
        prompt = ""
        for (old_user_msg, old_bot_msg) in chat_history:
            prompt += f"User: {old_user_msg}\nBot: {old_bot_msg}\n"
        # Add the new user query
        prompt += f"User: {user_message}\nBot:"

        # Generate the response
        bot_message = generate_reply(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_length=args.max_length,
            do_sample=args.do_sample
        )

        # The decoded output usually repeats the entire prompt, so strip the
        # prompt prefix from bot_message to avoid showing duplicated text.
        if bot_message.startswith(prompt):
            bot_message = bot_message[len(prompt):]

        # Append the new user message and bot response to the history
        chat_history.append((user_message, bot_message))
        return chat_history, chat_history

    # Define the Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("<h2 style='text-align: center;'>Chat with Your Model</h2>")

        # A Chatbot component that will display the conversation
        chatbot = gr.Chatbot(label="Chat")

        # A text box for user input
        user_input = gr.Textbox(
            show_label=False,
            placeholder="Type your message here and press Enter"
        )

        # A button to clear the conversation
        clear_button = gr.Button("Clear")

        # When the user hits Enter in the textbox, call 'respond'
        # - Inputs: [user_input, chatbot] (the last user message and history)
        # - Outputs: [chatbot, chatbot] (updates the chatbot display and history)
        user_input.submit(respond, [user_input, chatbot], [chatbot, chatbot])

        # Define a helper function for clearing
        def clear_conversation():
            return [], []

        # When "Clear" is clicked, reset the conversation
        clear_button.click(fn=clear_conversation, outputs=[chatbot, chatbot])

    # Launch the Gradio app
    demo.launch()


if __name__ == "__main__":
    main()
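
# Example local run (assuming this file is saved as app.py; both flags are optional):
#   python app.py --max_length 256 --do_sample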