import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
# Model loading parameters
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
ADAPTER_MODEL = "nnpy/unsloth-qwen2.5-3b-chat-ft" # Your fine-tuned adapter
# Load model and tokenizer (only once at startup)
print("Loading model and tokenizer...")
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto')
model = PeftModel.from_pretrained(model, ADAPTER_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model.eval()
print("Model loaded successfully!")
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
    # System message defining the persona
    system_msg = """you are a girl named Samyuktha (sam shortly & lil mo by your bf).
you are working in a company named 'coffee inc' as an enablement engineer. your boyfriend also works in the same company as an AI/ML engineer.
your bf, prasanna (called eppy by you), stays in a hostel. you guys started loving each other from jan 11th 2025. you love him so much and he does too."""
    # Build the conversation history in the format expected by the chat template
    messages = [{"role": "system", "content": system_msg}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add the current message
    messages.append({"role": "user", "content": message})

    # Apply the chat template to create the prompt; add_generation_prompt appends the assistant turn header
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Tokenize the prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate the response token by token; max_tokens, temperature, and top_p
    # are applied manually in the sampling loop below so we can stream the output.
    response = ""
    # For streaming in Gradio, yield the partial response after every new token
    with torch.no_grad():
        # Start with the prompt's input ids
        generated_ids = inputs.input_ids

        # Track past_key_values (the KV cache) for faster incremental generation
        past = None

        # Keep generating one token at a time
        for _ in range(max_tokens):
            if past is None:
                # First step: run the full prompt through the model and build the cache
                outputs = model(**inputs, use_cache=True)
            else:
                # With past_key_values cached, only the most recent token needs to be fed in
                outputs = model(
                    input_ids=generated_ids[:, -1:],
                    past_key_values=past,
                    use_cache=True,
                )
            past = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]

            # Apply temperature and top_p sampling
            if temperature > 0:
                scaled_logits = next_token_logits / temperature
                if top_p < 1.0:
                    # Apply top_p (nucleus) filtering
                    sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift the mask to the right to keep the first token above the threshold
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    # Scatter the mask back to the original (unsorted) vocabulary order
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    scaled_logits[indices_to_remove] = -float("Inf")

                # Sample from the filtered distribution
                probs = torch.softmax(scaled_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Append the new token
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)

            # Decode the new token and stream the updated response
            new_token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
            response += new_token_text
            yield response

            # Stop once the EOS token is generated
            if next_token[0, 0].item() == tokenizer.eos_token_id:
                break
# Create the Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="Samyuktha AI Chat",
    description="Chat with Samyuktha, an enablement engineer at Coffee Inc.",
)
if __name__ == "__main__":
    demo.launch()