prithivMLmods committed
Commit 908cadf · verified · 1 parent: f16ee26

Update app.py

Files changed (1): app.py (+8, −6)
app.py CHANGED
@@ -20,7 +20,7 @@ processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
 model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )
-model.eval()  # Ensure the model is in evaluation mode.
+model.eval()  # Set model to evaluation mode.
 
 MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
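Since the model is loaded once at import time, a quick sanity check can confirm these settings took effect. A minimal illustrative sketch, not part of the commit:

# Hypothetical check, not in the commit: eval() disables dropout, and the
# weights should have been loaded in bfloat16.
assert not model.training
print(next(model.parameters()).dtype)  # expected: torch.bfloat16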
@@ -172,7 +172,7 @@ def process_history(history: list[dict]) -> list[dict]:
 
 
 def generate_thread(generate_kwargs):
-    # Empty cache to free up memory and run generation under no_grad.
+    # Clear cache and run generation under no_grad.
     torch.cuda.empty_cache()
     with torch.no_grad():
         model.generate(**generate_kwargs)
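The wrapper keeps generation under no_grad. torch.inference_mode() is a stricter alternative that also disables view tracking and version counting and is usually at least as fast. A sketch of the same wrapper using it, assuming no tensor produced during generation needs autograd later:

def generate_thread(generate_kwargs):
    # Alternative sketch: inference_mode() is a stricter, often faster no_grad().
    torch.cuda.empty_cache()
    with torch.inference_mode():
        model.generate(**generate_kwargs)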
@@ -190,13 +190,15 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
     messages.extend(process_history(history))
     messages.append({"role": "user", "content": process_new_user_message(message)})
 
-    inputs = processor.apply_chat_template(
+    # Apply chat template and convert each tensor in the resulting dict.
+    raw_inputs = processor.apply_chat_template(
         messages,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
         return_tensors="pt",
-    ).to(device=model.device, dtype=torch.bfloat16)
+    )
+    inputs = {k: v.to(device=model.device, dtype=torch.bfloat16) for k, v in raw_inputs.items()}
 
     streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
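One caveat with the new per-tensor conversion: it casts every tensor in the dict to bfloat16, including integer tensors such as input_ids and attention_mask, and embedding lookups require integer indices. A more defensive sketch, assuming raw_inputs is a dict of tensors, casts only floating-point entries such as pixel_values:

# Sketch: move everything to the model's device, but cast only
# floating-point tensors; input_ids must stay integer.
inputs = {
    k: v.to(model.device, dtype=torch.bfloat16) if v.is_floating_point() else v.to(model.device)
    for k, v in raw_inputs.items()
}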
@@ -204,7 +206,7 @@
         streamer=streamer,
         max_new_tokens=max_new_tokens,
     )
-    # Launch generation in a separate thread using our no_grad wrapper.
+    # Launch generation in a separate thread.
     t = Thread(target=generate_thread, kwargs={"generate_kwargs": generate_kwargs})
     t.start()
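For context, the streamer set up above is drained on the caller's side while the worker thread generates. A minimal sketch of that consuming loop (the hypothetical consume() stands in for the tail of run(), which is outside this hunk):

# Sketch: TextIteratorStreamer is iterable and yields decoded text chunks
# until generation finishes or the 30 s timeout fires.
def consume(streamer):
    output = ""
    for delta in streamer:
        output += delta
        yield output  # gr.ChatInterface renders each partial string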
@@ -364,4 +366,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)
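share=True asks Gradio to open a public tunnel link; on Hugging Face Spaces the app is already public and Gradio warns that the flag is ignored there. A hedged guard, assuming the SPACE_ID environment variable that Spaces set (os is already imported in app.py):

# Sketch: request a share link only when running outside a Space.
if __name__ == "__main__":
    demo.launch(share=os.getenv("SPACE_ID") is None)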
 