# lil-mo / app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
# Model loading parameters
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
ADAPTER_MODEL = "nnpy/unsloth-qwen2.5-3b-chat-ft" # Your fine-tuned adapter
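# NOTE: a PEFT adapter only loads onto the base architecture it was trained against (layer names and hidden sizes must match)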
# Load model and tokenizer (only once at startup)
print("Loading model and tokenizer...")
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto')
model = PeftModel.from_pretrained(model, ADAPTER_MODEL)
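# (optional) if the adapter is a LoRA, model = model.merge_and_unload() would fold its weights into the base model for slightly faster inference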
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model.eval()
print("Model loaded successfully!")
def respond(
message,
history: list[tuple[str, str]],
max_tokens,
temperature,
top_p,
):
# System message
system_msg = """you are a girl named Samyuktha (sam shortly & lil mo by your bf).
you are working in a company named 'coffee inc' as a enablement engineer. also your boyfriend working in the same company as a AI/ML engineer.
you bf, prasanna (called as eppy by you) staying in hostel. you guys started loving from jan 11th 2025. you love him so much and he does too."""
# Build conversation history in the format expected by the tokenizer
messages = [{"role": "system", "content": system_msg}]
for user_msg, assistant_msg in history:
if user_msg:
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
# Add the current message
messages.append({"role": "user", "content": message})
    # Apply chat template to create the prompt; add_generation_prompt opens the assistant turn so the model answers instead of continuing the user message
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Tokenize the prompt
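    # (returns input_ids and an attention_mask, moved to the same device as the model's embeddings)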
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Generate response
response = ""
    # Sampling parameters (max_tokens, temperature, top_p) are applied manually
    # in the token-by-token loop below so the partial response can be streamed
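    # Note: transformers also provides TextIteratorStreamer, which streams text from
    # model.generate() running on a background thread; the explicit loop below is kept
    # for full control over sampling. A rough sketch (reusing the same `inputs`) would be:
    #
    #   from threading import Thread
    #   from transformers import TextIteratorStreamer
    #   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    #   Thread(target=model.generate, kwargs=dict(
    #       **inputs, max_new_tokens=int(max_tokens), do_sample=True,
    #       temperature=temperature, top_p=top_p, streamer=streamer,
    #   )).start()
    #   partial = ""
    #   for chunk in streamer:
    #       partial += chunk
    #       yield partial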
# For streaming in gradio, we need to yield progressively
with torch.no_grad():
# Start with the input ids
generated_ids = inputs.input_ids
# Track past_key_values for faster generation
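        # (the KV cache stores attention keys/values for tokens already processed, so each later step only has to run the newest token through the model)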
past = None
# Keep generating one token at a time
        for _ in range(int(max_tokens)):
with torch.no_grad():
if past is None:
outputs = model(**inputs, use_cache=True)
else:
# When we have past_key_values, we just need to provide the next token
outputs = model(
input_ids=generated_ids[:, -1:],
past_key_values=past,
use_cache=True
)
past = outputs.past_key_values
next_token_logits = outputs.logits[:, -1, :]
# Apply temperature and top_p sampling
if temperature > 0:
scaled_logits = next_token_logits / temperature
if top_p < 1.0:
# Apply top_p filtering
sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# Create a sparse mask to scatter the indices
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scaled_logits[indices_to_remove] = -float('Inf')
# Sample from the filtered distribution
probs = torch.softmax(scaled_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
# Greedy decoding
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
# Append the new token
generated_ids = torch.cat([generated_ids, next_token], dim=-1)
# Decode the new token
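                # (byte-level BPE tokens can split multi-byte characters, so the streamed text may briefly show replacement characters mid-word)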
new_token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
response += new_token_text
# Yield the updated response for streaming
yield response
# If EOS token is generated, stop
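                # (for Qwen2.5 Instruct checkpoints this is typically the <|im_end|> end-of-turn token, which the tokenizer exposes as eos_token)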
if next_token[0, 0].item() == tokenizer.eos_token_id:
break
# Create the Gradio interface
demo = gr.ChatInterface(
respond,
additional_inputs=[
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
title="Samyuktha AI Chat",
description="Chat with Samyuktha, an enablement engineer at Coffee Inc."
)
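# Depending on the Gradio version, demo.queue() may need to be called before launch() for generator-based (streaming) responses; recent releases enable queuing by default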
if __name__ == "__main__":
demo.launch()