dar-tau committed on
Commit
74d1efc
·
verified ·
1 Parent(s): 1937eb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -54,13 +54,15 @@ start_messages = [
54
  class PastKV:
55
  past_key_values: Optional[torch.Tensor] = None
56
 
 
 
57
 
58
  def past_kv_to_device(past_kv, device):
59
  return tuple((k.to(device).detach(), v.to(device).detach()) for k, v in past_kv)
60
 
61
 
62
  @spaces.GPU
63
- def get_past_key_values(system_prompt):
64
  model, tokenizer = pipe.model, pipe.tokenizer
65
  tokenized = tokenizer.apply_chat_template(start_messages, return_tensors='pt')
66
 
@@ -68,8 +70,9 @@ def get_past_key_values(system_prompt):
68
  test_messages = [*start_messages, {'role': 'user', 'content': 'Hello World!'}]
69
  tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
70
  assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
71
- past_key_values = model(tokenized.to(model.device)).past_key_values
72
- return PastKV(past_kv_to_device(past_key_values, 'cpu'))
 
73
 
74
  @spaces.GPU
75
  def generate(text, past_key_values):
@@ -84,7 +87,7 @@ def generate(text, past_key_values):
84
 
85
 
86
  if __name__ == "__main__":
87
- past_key_values = get_past_key_values(system_prompt)
88
  demo = gr.Interface(partial(generate, past_key_values=past_key_values.past_key_values),
89
  inputs="textbox", outputs="textbox")
90
  demo.launch()
 
54
  class PastKV:
55
  past_key_values: Optional[torch.Tensor] = None
56
 
57
+ past_key_values = PastKV()
58
+
59
 
60
  def past_kv_to_device(past_kv, device):
61
  return tuple((k.to(device).detach(), v.to(device).detach()) for k, v in past_kv)
62
 
63
 
64
  @spaces.GPU
65
+ def set_past_key_values():
66
  model, tokenizer = pipe.model, pipe.tokenizer
67
  tokenized = tokenizer.apply_chat_template(start_messages, return_tensors='pt')
68
 
 
70
  test_messages = [*start_messages, {'role': 'user', 'content': 'Hello World!'}]
71
  tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
72
  assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
73
+ past_key_values.past_key_values = PastKV(past_kv_to_device(model(tokenized.to(model.device)).past_key_values, 'cpu'))
74
+ return True
75
+
76
 
77
  @spaces.GPU
78
  def generate(text, past_key_values):
 
87
 
88
 
89
  if __name__ == "__main__":
90
+ set_past_key_values()
91
  demo = gr.Interface(partial(generate, past_key_values=past_key_values.past_key_values),
92
  inputs="textbox", outputs="textbox")
93
  demo.launch()