prithivMLmods commited on
Commit
a9ad97a
·
verified ·
1 Parent(s): 908cadf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -5
app.py CHANGED
@@ -201,11 +201,8 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
201
  inputs = {k: v.to(device=model.device, dtype=torch.bfloat16) for k, v in raw_inputs.items()}
202
 
203
  streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
204
- generate_kwargs = dict(
205
- inputs=inputs,
206
- streamer=streamer,
207
- max_new_tokens=max_new_tokens,
208
- )
209
  # Launch generation in a separate thread.
210
  t = Thread(target=generate_thread, kwargs={"generate_kwargs": generate_kwargs})
211
  t.start()
 
201
  inputs = {k: v.to(device=model.device, dtype=torch.bfloat16) for k, v in raw_inputs.items()}
202
 
203
  streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
204
+ # Unpack inputs into generate_kwargs so that each tensor is passed as a separate keyword argument.
205
+ generate_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
 
 
 
206
  # Launch generation in a separate thread.
207
  t = Thread(target=generate_thread, kwargs={"generate_kwargs": generate_kwargs})
208
  t.start()