Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -58,8 +58,8 @@ start_messages = [
|
|
58 |
# past_key_values = PastKV()
|
59 |
|
60 |
|
61 |
-
def past_kv_to_device(past_kv, device):
|
62 |
-
return tuple((torch.tensor(k).to(device).
|
63 |
|
64 |
def detach_past_kv(past_kv):
|
65 |
return tuple((k.cpu().detach().numpy().tolist(), v.cpu().detach().numpy().tolist()) for k, v in past_kv)
|
@@ -83,8 +83,9 @@ def generate(text, past_key_values):
|
|
83 |
*start_messages,
|
84 |
{'role': 'user', 'content': text}
|
85 |
]
|
|
|
86 |
response = pipe(messages,
|
87 |
-
past_key_values=
|
88 |
**generate_kwargs)[0]['generated_text']
|
89 |
return response[-1]['content']
|
90 |
|
|
|
58 |
# past_key_values = PastKV()
|
59 |
|
60 |
|
61 |
+
def past_kv_to_device(past_kv, device, dtype):
|
62 |
+
return tuple((torch.tensor(k).to(device).to(dtype), torch.tensor(v).to(device).to(dtype)) for k, v in past_kv)
|
63 |
|
64 |
def detach_past_kv(past_kv):
|
65 |
return tuple((k.cpu().detach().numpy().tolist(), v.cpu().detach().numpy().tolist()) for k, v in past_kv)
|
|
|
83 |
*start_messages,
|
84 |
{'role': 'user', 'content': text}
|
85 |
]
|
86 |
+
past_key_values = past_kv_to_device(past_key_values, pipe.model.device, pipe.model.dtype)
|
87 |
response = pipe(messages,
|
88 |
+
past_key_values=past_key_values,
|
89 |
**generate_kwargs)[0]['generated_text']
|
90 |
return response[-1]['content']
|
91 |
|