yasserrmd committed
Commit 1dbadd4 · verified · 1 Parent(s): 57a3b13

Create app.py

Files changed (1)
  1. app.py +222 -0
app.py ADDED
@@ -0,0 +1,222 @@
import os
import threading
import time
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

MODEL_ID = os.getenv("MODEL_ID", "yasserrmd/SoftwareArchitecture-Instruct-v1")

# -------- Load model & tokenizer --------
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval()
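# Note: device_map="auto" needs the `accelerate` package installed; it places the
# model on a GPU when one is available and falls back to CPU otherwise.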

# Ensure a pad token to avoid warnings on some bases
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

TITLE = "SoftwareArchitecture-Instruct v1 — Chat"
DESCRIPTION = (
    "An instruction-tuned LLM for **software architecture**. "
    "Built on LiquidAI/LFM2-1.2B, fine-tuned with the Software-Architecture dataset. "
    "Designed for technical professionals: accurate, detailed, and on-topic answers."
)

SAMPLES = [
    "Explain the API Gateway pattern and when to use it.",
    "CQRS vs Event Sourcing — how do they relate, and when would you combine them?",
    "Design a resilient payment workflow with retries, idempotency keys, and DLQ.",
    "Rate limiting strategies for a public REST API: token bucket vs sliding window.",
    "Multi-tenant SaaS: compare shared DB, schema, and dedicated DB for isolation.",
    "Blue/green vs canary deployments — trade-offs and where each fits best.",
]

def format_history_as_messages(history):
    """
    Convert Gradio chat history into OpenAI-style messages for apply_chat_template.
    history: list of tuples (user, assistant)
    """
    messages = []
    for (u, a) in history:
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    return messages
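# Example: format_history_as_messages([("Hi", "Hello!"), ("What is CQRS?", None)]) returns
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "What is CQRS?"}]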

def stream_generate(messages, max_new_tokens, temperature, top_p, repetition_penalty, seed=None):
    """
    Stream text from model.generate using TextIteratorStreamer.
    """
    if seed is not None and seed >= 0:
        torch.manual_seed(seed)

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # IMPORTANT for chat models
        return_tensors="pt",
        tokenize=True,
        return_dict=True,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True if temperature > 0 else False,
        use_cache=True,
        streamer=streamer,
    )

    # Run generation in a thread so we can yield from streamer
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
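# Standalone usage of stream_generate (illustrative):
#   for text in stream_generate(
#       [{"role": "user", "content": "Explain the API Gateway pattern."}],
#       max_new_tokens=128, temperature=0.3, top_p=0.9, repetition_penalty=1.05,
#   ):
#       print(text)
# Each iteration yields the full text generated so far, not just the newest chunk.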

# -------- Gradio callbacks --------

def chat_respond(user_msg, chat_history, max_new_tokens, temperature, top_p, repetition_penalty, seed):
    if not user_msg or not user_msg.strip():
        # Nothing to do: leave the chatbot and the textbox unchanged.
        # (This function is a generator, so it must yield rather than return a value.)
        yield gr.update(), gr.update()
        return

    # Add user turn
    chat_history = chat_history + [(user_msg, None)]

    # Build messages from full history
    messages = format_history_as_messages(chat_history)

    # Stream assistant output
    stream = stream_generate(
        messages=messages,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        seed=int(seed) if seed is not None else None,
    )

    # Yield progressive updates for the last assistant turn
    final_assistant_text = ""
    for chunk in stream:
        final_assistant_text = chunk
        yield gr.update(value=chat_history[:-1] + [(user_msg, final_assistant_text)]), ""

    # Ensure final state returned
    chat_history[-1] = (user_msg, final_assistant_text)
    yield gr.update(value=chat_history), ""
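# Because chat_respond is a generator, Gradio streams every yielded
# (chatbot, textbox) pair to the browser: the assistant reply grows in place
# while the input box is cleared on the first update.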

def use_sample(sample, chat_history):
    return sample, chat_history

def clear_chat():
    return []

# -------- UI --------

CUSTOM_CSS = """
:root {
  --brand: #0ea5e9; /* cyan-500 */
  --ink: #0b1220;
}
.gradio-container {
  font-family: Inter, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, "Apple Color Emoji","Segoe UI Emoji";
}
#title h1 {
  font-weight: 700;
  letter-spacing: -0.02em;
}
#desc {
  opacity: 0.9;
}
footer {visibility: hidden}
"""

with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(primary_hue="cyan")) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(f"<div id='title'><h1>{TITLE}</h1></div>")
            gr.Markdown(f"<div id='desc'>{DESCRIPTION}</div>", elem_id="desc")

    with gr.Row():
        with gr.Column(scale=4):
            chat = gr.Chatbot(
                label="SoftwareArchitecture-Instruct v1",
                avatar_images=(None, None),
                height=480,
                bubble_full_width=False,
                likeable=False,
                sanitize_html=False,
            )
            with gr.Row():
                user_box = gr.Textbox(
                    placeholder="Ask about software architecture…",
                    show_label=False,
                    lines=3,
                    autofocus=True,
                    scale=4,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)

            with gr.Accordion("Generation Settings", open=False):
                max_new_tokens = gr.Slider(64, 1024, value=256, step=16, label="Max new tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.3, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                repetition_penalty = gr.Slider(1.0, 1.5, value=1.05, step=0.01, label="Repetition penalty")
                seed = gr.Number(value=-1, precision=0, label="Seed (-1 for random)")

            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                # sample buttons
                sample_dropdown = gr.Dropdown(choices=SAMPLES, label="Samples", value=None)
                use_sample_btn = gr.Button("Use Sample")

        with gr.Column(scale=2):
            gr.Markdown("### Samples")
            gr.Markdown("\n".join([f"• {s}" for s in SAMPLES]))
            gr.Markdown("—\n**Tip:** Increase *Max new tokens* for longer, more complete answers.")

    # Events
    send_btn.click(
        chat_respond,
        inputs=[user_box, chat, max_new_tokens, temperature, top_p, repetition_penalty, seed],
        outputs=[chat, user_box],
        queue=True,
        show_progress=True,
    )
    user_box.submit(
        chat_respond,
        inputs=[user_box, chat, max_new_tokens, temperature, top_p, repetition_penalty, seed],
        outputs=[chat, user_box],
        queue=True,
        show_progress=True,
    )
    clear_btn.click(fn=clear_chat, outputs=chat)

    use_sample_btn.click(use_sample, inputs=[sample_dropdown, chat], outputs=[user_box, chat])

    gr.Markdown(
        "—\nBuilt for engineers and architects. Base model: **LiquidAI/LFM2-1.2B** · Fine-tuned: **Software-Architecture** dataset."
    )

if __name__ == "__main__":
    demo.queue().launch()
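
To try the app locally (a sketch; the commit pins no package versions): install `torch`, `transformers`, `gradio`, and `accelerate` (assumed here because of `device_map="auto"`), then run `python app.py`. Set the `MODEL_ID` environment variable to point the app at a different checkpoint.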