Update app.py
app.py CHANGED
@@ -20,7 +20,7 @@ processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
 model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )
-model.eval()  #
+model.eval()  # Set model to evaluation mode.
 
 MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
 
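Note: model.eval() only toggles module training flags such as dropout; it does not disable gradient tracking, which is why generate_thread below still wraps generation in torch.no_grad(). A minimal self-contained sketch of what the flag changes (toy layer, not from app.py):

import torch

layer = torch.nn.Dropout(p=0.5)  # stand-in for a dropout module inside the model
x = torch.ones(4)

layer.train()    # training mode: dropout is active
print(layer(x))  # roughly half the entries zeroed, survivors scaled to 2.0

layer.eval()     # evaluation mode: dropout becomes a no-op
print(layer(x))  # tensor([1., 1., 1., 1.])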
@@ -172,7 +172,7 @@ def process_history(history: list[dict]) -> list[dict]:
 
 
 def generate_thread(generate_kwargs):
-    #
+    # Clear cache and run generation under no_grad.
    torch.cuda.empty_cache()
    with torch.no_grad():
        model.generate(**generate_kwargs)
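Note: the new comment is accurate on both counts. torch.cuda.empty_cache() releases cached allocator blocks before the next request, and torch.no_grad() stops autograd from recording the forward pass, so no activation graph is kept alive during generation. A small self-contained sketch of the no_grad effect:

import torch

w = torch.randn(8, 8, requires_grad=True)

y = w @ w
print(y.requires_grad)  # True: autograd tracked the matmul

with torch.no_grad():
    y = w @ w
print(y.requires_grad)  # False: no graph was built, so no extra memory is held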
@@ -190,13 +190,15 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
     messages.extend(process_history(history))
     messages.append({"role": "user", "content": process_new_user_message(message)})
 
-    inputs = processor.apply_chat_template(
+    # Apply the chat template, then move tensors to the model device, casting only floating-point tensors to bfloat16.
+    raw_inputs = processor.apply_chat_template(
         messages,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
         return_tensors="pt",
-    ).to(device=model.device, dtype=torch.bfloat16)
+    )
+    inputs = {k: v.to(model.device, dtype=torch.bfloat16) if v.is_floating_point() else v.to(model.device) for k, v in raw_inputs.items()}
 
     streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
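Note: the is_floating_point() guard matters here. The replaced one-liner relied on BatchFeature.to(), which applies the dtype only to floating-point tensors; a naive comprehension that casts every value would turn input_ids into bfloat16 and break the embedding lookup in generate(). A self-contained sketch (keys and shapes are illustrative, not copied from app.py):

import torch

raw_inputs = {
    "input_ids": torch.tensor([[1, 2, 3]]),       # int64: token ids must stay integral
    "attention_mask": torch.tensor([[1, 1, 1]]),  # int64: the mask stays integral too
    "pixel_values": torch.randn(1, 3, 4, 4),      # float32: safe to cast down
}
inputs = {
    k: v.to(dtype=torch.bfloat16) if v.is_floating_point() else v
    for k, v in raw_inputs.items()
}
print({k: v.dtype for k, v in inputs.items()})
# {'input_ids': torch.int64, 'attention_mask': torch.int64, 'pixel_values': torch.bfloat16}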
@@ -204,7 +206,7 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
         streamer=streamer,
         max_new_tokens=max_new_tokens,
     )
-    # Launch generation in a separate thread
+    # Launch generation in a separate thread.
     t = Thread(target=generate_thread, kwargs={"generate_kwargs": generate_kwargs})
     t.start()
 
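Note: running model.generate() on a worker thread is what lets run() iterate over the streamer on the calling thread and yield partial text to Gradio as it arrives. TextIteratorStreamer is backed by a queue, so the pattern is an ordinary producer/consumer; a self-contained analogue (the toy producer stands in for model.generate(), names are illustrative):

from queue import Queue
from threading import Thread

def produce(q: Queue) -> None:
    # stands in for model.generate() pushing decoded chunks into the streamer
    for chunk in ["Hel", "lo, ", "wor", "ld!"]:
        q.put(chunk)
    q.put(None)  # sentinel: generation finished

q: Queue = Queue()
Thread(target=produce, args=(q,)).start()

output = ""
while (chunk := q.get(timeout=30.0)) is not None:
    output += chunk
    print(output)  # run() would `yield output` here to stream into the UI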
@@ -364,4 +366,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)
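Note: share=True asks Gradio to tunnel the app through a temporary public *.gradio.live URL. A Space is already served publicly, so the flag mainly matters when running app.py on your own machine. A minimal sketch (toy interface, not the app's ChatInterface):

import gradio as gr

demo = gr.Interface(fn=lambda s: s[::-1], inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch(share=True)  # prints a public share link alongside the local URL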