dar-tau commited on
Commit
df678ec
·
verified ·
1 Parent(s): a2adaed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -58,8 +58,8 @@ start_messages = [
58
  # past_key_values = PastKV()
59
 
60
 
61
- def past_kv_to_device(past_kv, device):
62
- return tuple((torch.tensor(k).to(device).detach(), torch.tensor(v).to(device).detach()) for k, v in past_kv)
63
 
64
  def detach_past_kv(past_kv):
65
  return tuple((k.cpu().detach().numpy().tolist(), v.cpu().detach().numpy().tolist()) for k, v in past_kv)
@@ -83,8 +83,9 @@ def generate(text, past_key_values):
83
  *start_messages,
84
  {'role': 'user', 'content': text}
85
  ]
 
86
  response = pipe(messages,
87
- past_key_values=past_kv_to_device(past_key_values, pipe.model.device),
88
  **generate_kwargs)[0]['generated_text']
89
  return response[-1]['content']
90
 
 
58
  # past_key_values = PastKV()
59
 
60
 
61
+ def past_kv_to_device(past_kv, device, dtype):
62
+ return tuple((torch.tensor(k).to(device).to(dtype), torch.tensor(v).to(device).to(dtype)) for k, v in past_kv)
63
 
64
  def detach_past_kv(past_kv):
65
  return tuple((k.cpu().detach().numpy().tolist(), v.cpu().detach().numpy().tolist()) for k, v in past_kv)
 
83
  *start_messages,
84
  {'role': 'user', 'content': text}
85
  ]
86
+ past_key_values = past_kv_to_device(past_key_values, pipe.model.device, pipe.model.dtype)
87
  response = pipe(messages,
88
+ past_key_values=past_key_values,
89
  **generate_kwargs)[0]['generated_text']
90
  return response[-1]['content']
91