dar-tau commited on
Commit
0faca03
·
verified ·
1 Parent(s): 8555522

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -45,6 +45,11 @@ start_messages = [
45
 
46
  torch.set_grad_enabled(False)
47
 
 
 
 
 
 
48
  @spaces.GPU
49
  def get_past_key_values(system_prompt):
50
  model, tokenizer = pipe.model, pipe.tokenizer
@@ -54,8 +59,8 @@ def get_past_key_values(system_prompt):
54
  test_messages = [*start_messages, {'role': 'user', 'content': 'Hello World!'}]
55
  tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
56
  assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
57
- return model(tokenized.to(model.device)).past_key_values.cpu().detach()
58
-
59
 
60
  @spaces.GPU
61
  def generate(text, past_key_values):
@@ -64,7 +69,7 @@ def generate(text, past_key_values):
64
  {'role': 'user', 'content': text}
65
  ]
66
  response = pipe(messages,
67
- past_key_values=past_key_values.to(model.device),
68
  **generate_kwargs)[0]['generated_text']
69
  return response[-1]['content']
70
 
 
45
 
46
  torch.set_grad_enabled(False)
47
 
48
+
49
+ def past_kv_to_device(past_kv, device):
50
+ return [(k.to(device).detach(), v.to(device).detach()) for k, v in past_kv.items()]
51
+
52
+
53
  @spaces.GPU
54
  def get_past_key_values(system_prompt):
55
  model, tokenizer = pipe.model, pipe.tokenizer
 
59
  test_messages = [*start_messages, {'role': 'user', 'content': 'Hello World!'}]
60
  tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
61
  assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
62
+ past_key_values = model(tokenized.to(model.device)).past_key_values
63
+ return past_kv_to_device(past_key_values, 'cpu')
64
 
65
  @spaces.GPU
66
  def generate(text, past_key_values):
 
69
  {'role': 'user', 'content': text}
70
  ]
71
  response = pipe(messages,
72
+ past_key_values=past_kv_to_device(past_key_values, model.device),
73
  **generate_kwargs)[0]['generated_text']
74
  return response[-1]['content']
75