Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -21,6 +21,10 @@ prompt_format = '''<|im_start|>system
|
|
21 |
<|im_start|>assistant
|
22 |
'''
|
23 |
|
|
|
|
|
|
|
|
|
24 |
|
25 |
system_prompt = '''You are given a partial input text for another AI chat interface.
|
26 |
Propose auto-completion to the text. You have several roles:
|
@@ -89,12 +93,12 @@ def detach_past_kv(past_kv):
|
|
89 |
@spaces.GPU
|
90 |
def set_past_key_values():
|
91 |
model, tokenizer = pipe.model, pipe.tokenizer
|
92 |
-
tokenized = tokenizer.
|
93 |
-
|
94 |
# Check that this is indeed a prefix of the entire message
|
95 |
-
test_messages = [*start_messages, {'role': 'user', 'content': 'Hello World!'}]
|
96 |
-
tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
|
97 |
-
assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
|
98 |
return detach_past_kv(model(tokenized.to(model.device)).past_key_values)
|
99 |
|
100 |
|
@@ -118,8 +122,7 @@ def generate(text, past_key_values):
|
|
118 |
|
119 |
if __name__ == "__main__":
|
120 |
with torch.no_grad():
|
121 |
-
|
122 |
-
|
123 |
-
demo = gr.Interface(partial(generate, past_key_values=None),
|
124 |
inputs="textbox", outputs="textbox")
|
125 |
demo.launch()
|
|
|
21 |
<|im_start|>assistant
|
22 |
'''
|
23 |
|
24 |
+
# Prompt prefix holding only the system turn plus the opening of the user
# turn. `{system_message}` is filled in with str.format(); the user text is
# expected to be appended after the trailing newline.
system_only_prompt_format = '''<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
'''
|
28 |
|
29 |
system_prompt = '''You are given a partial input text for another AI chat interface.
|
30 |
Propose auto-completion to the text. You have several roles:
|
|
|
93 |
@spaces.GPU
def set_past_key_values():
    """Run the model once over the fixed system-prompt prefix and return its cache.

    The prefix is `system_only_prompt_format` filled with `system_prompt`.

    Returns:
        The result of `detach_past_kv` applied to the model's
        `past_key_values` for that prefix.
    """
    model, tokenizer = pipe.model, pipe.tokenizer
    prefix_text = system_only_prompt_format.format(system_message=system_prompt)
    # Without return_tensors='pt' the tokenizer returns plain Python lists,
    # which can neither be moved to a device nor fed to the model. Request
    # tensors and pass only `input_ids`, matching what the previous
    # apply_chat_template(..., return_tensors='pt') call produced.
    # NOTE(review): tokenizer() adds special tokens by default, which
    # apply_chat_template may not -- confirm the prefix matches for this
    # tokenizer.
    tokenized = tokenizer(prefix_text, return_tensors='pt').input_ids
    # NOTE: the earlier sanity check (that this prefix equals the start of a
    # full chat-template rendering) is disabled in this revision.
    return detach_past_kv(model(tokenized.to(model.device)).past_key_values)
|
103 |
|
104 |
|
|
|
122 |
|
123 |
if __name__ == "__main__":
    # NOTE(review): nesting under `with` is reconstructed; the original
    # indentation was not visible -- confirm against app.py.
    with torch.no_grad():
        # Compute the prompt-prefix cache once at startup and bind it to the
        # handler so every request is given the same object.
        past_key_values = set_past_key_values()
        handler = partial(generate, past_key_values=past_key_values)
        demo = gr.Interface(handler, inputs="textbox", outputs="textbox")
        demo.launch()
|