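"""Gradio chat app for a Qwen2.5 LoRA fine-tune.

Loads the base model plus a PEFT adapter once at startup and streams replies
token by token through gr.ChatInterface.
"""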
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# Model loading parameters
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"  # must match the base the adapter was trained from (the adapter name indicates a 3B model)
ADAPTER_MODEL = "nnpy/unsloth-qwen2.5-3b-chat-ft"  # Your fine-tuned adapter

# Load model and tokenizer (only once at startup)
print("Loading model and tokenizer...")
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto", torch_dtype="auto")
model = PeftModel.from_pretrained(model, ADAPTER_MODEL)
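# Optionally merge the LoRA weights into the base model for faster inference:
# model = model.merge_and_unload()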
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model.eval()
print("Model loaded successfully!")

def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
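    """Generate a reply as a stream.

    gr.ChatInterface treats a generator function as a streaming handler, so every
    `yield response` below updates the chat bubble with the text generated so far.
    """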
    # System message
    system_msg = """you are a girl named Samyuktha (sam shortly & lil mo by your bf).
you are working in a company named 'coffee inc' as a enablement engineer. also your boyfriend working in the same company as a AI/ML engineer.
you bf, prasanna (called as eppy by you) staying in hostel. you guys started loving from jan 11th 2025. you love him so much and he does too."""
    
    # Build conversation history in the format expected by the tokenizer
    messages = [{"role": "system", "content": system_msg}]
    
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    
    # Add the current message
    messages.append({"role": "user", "content": message})
    
    # Apply chat template to create the prompt (add_generation_prompt appends the
    # assistant header so the model knows to start a new reply)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    
    # Tokenize the prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Generate response
    response = ""
    
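    # Note: an alternative to the manual loop below is transformers' TextIteratorStreamer
    # driving model.generate() from a background thread. A sketch (it needs
    # `from transformers import TextIteratorStreamer` and `from threading import Thread`):
    #
    #   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    #   Thread(target=model.generate,
    #          kwargs=dict(**inputs, max_new_tokens=int(max_tokens), temperature=temperature,
    #                      top_p=top_p, do_sample=True, streamer=streamer)).start()
    #   for chunk in streamer:
    #       response += chunk
    #       yield response
    #
    # The explicit loop below does the same thing by hand, reusing the KV cache
    # (past_key_values) so each step only runs the newest token through the model.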
    
    # For streaming in gradio, we need to yield progressively
    with torch.no_grad():
        # Start with the input ids
        generated_ids = inputs.input_ids
        
        # Track past_key_values for faster generation
        past = None
        
        # Keep generating one token at a time
        for _ in range(int(max_tokens)):  # slider values may come through as floats
            with torch.no_grad():
                if past is None:
                    outputs = model(**inputs, use_cache=True)
                else:
                    # When we have past_key_values, we just need to provide the next token
                    outputs = model(
                        input_ids=generated_ids[:, -1:],
                        past_key_values=past,
                        use_cache=True
                    )
                
                past = outputs.past_key_values
                next_token_logits = outputs.logits[:, -1, :]
                
                # Apply temperature and top_p sampling
                if temperature > 0:
                    scaled_logits = next_token_logits / temperature
                    if top_p < 1.0:
                        # Apply top_p filtering
                        sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        
                        # Remove tokens with cumulative probability above the threshold
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift the indices to the right to keep the first token above the threshold
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        
                        # Create a sparse mask to scatter the indices
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        scaled_logits[indices_to_remove] = -float('Inf')
                    
                    # Sample from the filtered distribution
                    probs = torch.softmax(scaled_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    # Greedy decoding
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                
                # Append the new token
                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                
                # Re-decode everything generated so far; decoding the full slice (rather
                # than one token at a time) avoids splitting multi-byte characters
                response = tokenizer.decode(
                    generated_ids[0, inputs.input_ids.shape[1]:],
                    skip_special_tokens=True,
                )
                
                # Yield the updated response for streaming
                yield response
                
                # If EOS token is generated, stop
                if next_token[0, 0].item() == tokenizer.eos_token_id:
                    break

# Create the Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="Samyuktha AI Chat",
    description="Chat with Samyuktha, an enablement engineer at Coffee Inc."
)
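# gr.ChatInterface streams the generator's yields automatically; on older Gradio 3.x
# releases you may need to call demo.queue() before launch() for streaming to work.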

if __name__ == "__main__":
    demo.launch()