dar-tau committed on
Commit
6b3281f
·
verified ·
1 Parent(s): 497a54c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -21,6 +21,10 @@ prompt_format = '''<|im_start|>system
21
  <|im_start|>assistant
22
  '''
23
 
 
 
 
 
24
 
25
  system_prompt = '''You are given a partial input text for another AI chat interface.
26
  Propose auto-completion to the text. You have several roles:
@@ -89,12 +93,12 @@ def detach_past_kv(past_kv):
89
  @spaces.GPU
90
  def set_past_key_values():
91
  model, tokenizer = pipe.model, pipe.tokenizer
92
- tokenized = tokenizer.apply_chat_template(start_messages, return_tensors='pt')
93
-
94
  # Check that this is indeed a prefix of the entire message
95
- test_messages = [*start_messages, {'role': 'user', 'content': 'Hello World!'}]
96
- tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
97
- assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
98
  return detach_past_kv(model(tokenized.to(model.device)).past_key_values)
99
 
100
 
@@ -118,8 +122,7 @@ def generate(text, past_key_values):
118
 
119
  if __name__ == "__main__":
120
  with torch.no_grad():
121
- # past_key_values = set_past_key_values()
122
- # print(f'{past_key_values=}')
123
- demo = gr.Interface(partial(generate, past_key_values=None),
124
  inputs="textbox", outputs="textbox")
125
  demo.launch()
 
21
  <|im_start|>assistant
22
  '''
23
 
24
+ system_only_prompt_format = '''<|im_start|>system
25
+ {system_message}<|im_end|>
26
+ <|im_start|>user
27
+ '''
28
 
29
  system_prompt = '''You are given a partial input text for another AI chat interface.
30
  Propose auto-completion to the text. You have several roles:
 
93
  @spaces.GPU
94
  def set_past_key_values():
95
  model, tokenizer = pipe.model, pipe.tokenizer
96
+ tokenized = tokenizer(system_only_prompt_format.format(system_message=system_prompt))
97
+ # tokenized = tokenizer.apply_chat_template(start_messages, return_tensors='pt')
98
  # Check that this is indeed a prefix of the entire message
99
+ # test_messages = [*start_messages, {'role': 'user', 'content': 'Hello World!'}]
100
+ # tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
101
+ # assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
102
  return detach_past_kv(model(tokenized.to(model.device)).past_key_values)
103
 
104
 
 
122
 
123
  if __name__ == "__main__":
124
  with torch.no_grad():
125
+ past_key_values = set_past_key_values()
126
+ demo = gr.Interface(partial(generate, past_key_values=past_key_values),
 
127
  inputs="textbox", outputs="textbox")
128
  demo.launch()