dar-tau committed on
Commit
92585dc
·
verified ·
1 Parent(s): d3017cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -33,34 +33,42 @@ Assistant: "" (nothing much to contribute at this point. return nothing)
33
  (3)
34
  User: "Help me find a present for my"
35
  Assistant: "girlfriend;mother;father;friend"
 
36
  '''
37
 
 
 
 
 
 
 
38
 
39
  @spaces.GPU
40
  def get_past_key_values(system_prompt):
41
  model, tokenizer = pipe.model, pipe.tokenizer
42
- messages = [{'role': 'system', 'content': system_prompt}]
43
- tokenized = tokenizer.apply_chat_template(messages, return_tensors='pt')
44
 
45
- # assert that this is indeed a prefix (TODO: make sure this is robust)
46
- messages.append({'role': 'user', 'content': 'TEST'})
47
- tokenized_test = tokenizer.apply_chat_template(messages, return_tensors='pt')
48
  assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
49
-
50
  return model(tokenized.to(model.device)).past_key_values
51
 
52
 
53
  @spaces.GPU
54
- def generate(text):
55
  messages = [
56
- {'role': 'system', 'content': system_prompt},
57
  {'role': 'user', 'content': text}
58
  ]
59
- response = pipe(messages, **generate_kwargs)[0]['generated_text']
 
 
60
  return response[-1]['content']
61
 
62
 
63
  if __name__ == "__main__":
64
  past_key_values = get_past_key_values(system_prompt)
65
- demo = gr.Interface(generate, inputs="textbox", outputs="textbox")
 
66
  demo.launch()
 
33
  (3)
34
  User: "Help me find a present for my"
35
  Assistant: "girlfriend;mother;father;friend"
36
+ You will now get a blank message from the user and then after your answer, the user will give you the text to complete.
37
  '''
38
 
39
+ start_messages = [
40
+ {'role': 'system', 'content': system_prompt},
41
+ {'role': 'user', 'content': ' '},
42
+ {'role': 'assistant', 'content': '<Waiting for text>'}
43
+ ]
44
+
45
 
46
  @spaces.GPU
47
  def get_past_key_values(system_prompt):
48
  model, tokenizer = pipe.model, pipe.tokenizer
49
+ tokenized = tokenizer.apply_chat_template(start_messages, return_tensors='pt')
 
50
 
51
+ # Check that this is indeed a prefix of the entire message
52
+ test_messages = [*start_messages, {'role': 'user', 'content': 'Hello World!'}]
53
+ tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
54
  assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
 
55
  return model(tokenized.to(model.device)).past_key_values
56
 
57
 
58
  @spaces.GPU
59
+ def generate(text, past_key_values):
60
  messages = [
61
+ *start_messages,
62
  {'role': 'user', 'content': text}
63
  ]
64
+ response = pipe(messages,
65
+ past_key_values=past_key_values,
66
+ **generate_kwargs)[0]['generated_text']
67
  return response[-1]['content']
68
 
69
 
70
  if __name__ == "__main__":
71
  past_key_values = get_past_key_values(system_prompt)
72
+ demo = gr.Interface(partial(generate, past_key_values=past_key_values),
73
+ inputs="textbox", outputs="textbox")
74
  demo.launch()