Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -57,12 +57,14 @@ start_messages = [
|
|
57 |
|
58 |
# past_key_values = PastKV()
|
59 |
|
|
|
60 |
def past_kv_to_device(past_kv, device):
    """Return the key/value cache as detached tensors placed on ``device``.

    Each element of ``past_kv`` is unpacked as a ``(key, value)`` pair. An
    entry may be a ``torch.Tensor`` or the nested-list form returned by
    ``detach_past_kv``; lists are converted to tensors first.

    Args:
        past_kv: Iterable of ``(key, value)`` pairs.
        device: Target device, forwarded to ``Tensor.to``.

    Returns:
        A tuple of ``(key, value)`` tensor pairs, detached from autograd.
    """

    def _to_detached(entry):
        # Nested lists (the output of detach_past_kv) have no ``.to``; wrap
        # them in a tensor before moving. Tensors take the original path.
        if not isinstance(entry, torch.Tensor):
            entry = torch.tensor(entry)
        return entry.to(device).detach()

    return tuple((_to_detached(k), _to_detached(v)) for k, v in past_kv)
|
62 |
|
63 |
def detach_past_kv(past_kv):
    """Convert the key/value cache into plain nested Python lists.

    Each element of ``past_kv`` is unpacked as a ``(key, value)`` pair of
    tensors. Every tensor is detached from autograd, copied to the CPU and
    turned into nested lists of Python scalars.

    Args:
        past_kv: Iterable of ``(key, value)`` tensor pairs.

    Returns:
        A tuple of ``(key_list, value_list)`` pairs.
    """
    # Tensor.tolist() is used directly instead of going through .numpy():
    # NumPy has no bfloat16, so the NumPy hop raises TypeError for such caches.
    return tuple(
        (k.detach().cpu().tolist(), v.detach().cpu().tolist())
        for k, v in past_kv
    )
|
65 |
|
|
|
66 |
@spaces.GPU
|
67 |
def set_past_key_values():
|
68 |
model, tokenizer = pipe.model, pipe.tokenizer
|
@@ -91,6 +93,6 @@ if __name__ == "__main__":
|
|
91 |
with torch.no_grad():
|
92 |
past_key_values = set_past_key_values()
|
93 |
print(f'{past_key_values=}')
|
94 |
-
demo = gr.Interface(generate,
|
95 |
-
inputs=
|
96 |
demo.launch()
|
|
|
57 |
|
58 |
# past_key_values = PastKV()
|
59 |
|
60 |
+
|
61 |
def past_kv_to_device(past_kv, device):
    """Rebuild the key/value cache as detached tensors placed on ``device``.

    Each element of ``past_kv`` is unpacked as a ``(key, value)`` pair. An
    entry may be the nested-list form returned by ``detach_past_kv`` or an
    existing ``torch.Tensor``; either way a fresh tensor is produced.

    Args:
        past_kv: Iterable of ``(key, value)`` pairs.
        device: Target device, forwarded to torch.

    Returns:
        A tuple of ``(key, value)`` tensor pairs, detached from autograd.
    """

    def _fresh_tensor(entry):
        if isinstance(entry, torch.Tensor):
            # torch.tensor(tensor) copies but emits a UserWarning;
            # detach().clone() is the documented equivalent without it.
            return entry.detach().clone().to(device)
        # Build list data directly on the target device.
        return torch.tensor(entry, device=device).detach()

    return tuple((_fresh_tensor(k), _fresh_tensor(v)) for k, v in past_kv)
|
63 |
|
64 |
def detach_past_kv(past_kv):
    """Serialize the key/value cache into plain nested Python lists.

    Each element of ``past_kv`` is unpacked as a ``(key, value)`` pair of
    tensors. Every tensor is detached from autograd, copied to the CPU and
    converted to nested lists of Python scalars.

    Args:
        past_kv: Iterable of ``(key, value)`` tensor pairs.

    Returns:
        A tuple of ``(key_list, value_list)`` pairs.
    """

    def _as_lists(tensor):
        # Tensor.tolist() is called directly rather than via .numpy():
        # NumPy has no bfloat16, so the NumPy hop raises TypeError there.
        return tensor.detach().cpu().tolist()

    return tuple((_as_lists(k), _as_lists(v)) for k, v in past_kv)
|
66 |
|
67 |
+
|
68 |
@spaces.GPU
|
69 |
def set_past_key_values():
|
70 |
model, tokenizer = pipe.model, pipe.tokenizer
|
|
|
93 |
with torch.no_grad():
|
94 |
past_key_values = set_past_key_values()
|
95 |
print(f'{past_key_values=}')
|
96 |
+
demo = gr.Interface(partial(generate, past_key_values=past_kv_to_device(past_key_values, pipe.model.device)),
|
97 |
+
inputs="textbox", outputs="textbox")
|
98 |
demo.launch()
|