Update app.py
app.py
CHANGED
@@ -9,11 +9,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from dataclasses import dataclass
 
 
-# chatml_template = """{% for message in messages %}
-# {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
-# {% endfor %}"""
-# pipe.tokenizer.chat_template = chatml_template  # TheBloke says this is the right template for this model
-
 prompt_format = '''<|im_start|>system
 {system_message}<|im_end|>
 <|im_start|>user
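Note: the commit keeps the hand-rolled ChatML-style prompt_format string and drops the commented-out chat_template experiment. The hunk is cut off after the user turn, so the sketch below fills in the {prompt} field (confirmed by the later .format(..., prompt=text) call) and assumes a standard trailing assistant tag; only the first three lines are actually visible in the diff.

prompt_format = '''<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
'''  # the assistant tag is an assumption, not shown in this hunk

print(prompt_format.format(system_message="You are a helpful assistant.",
                           prompt="Suggest three follow-up questions."))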
@@ -51,37 +46,10 @@ Assistant: girlfriend;mother;father;friend
 
 # setup
 torch.set_grad_enabled(False)
-device = "cpu"
 model_name = "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ"
 pipe = pipeline("text-generation", model=model_name, device='cuda')
 generate_kwargs = {'max_new_tokens': 20}
 
-# '''
-# You will now get a blank message from the user and then after your answer, the user will give you the text to complete:
-# Example:
-
-# >> User:
-# >> Assistant: <Waiting for text>
-# >> User: Help me write a sentiment analysis pipeline
-# >> Assistant: using huggingface;using NLTK;using python
-# '''
-
-
-start_messages = [
-    {'role': 'system', 'content': system_prompt},
-    # {'role': 'user', 'content': ' '},
-    # {'role': 'assistant', 'content': '<Waiting for text>'}
-]
-
-
-# functions
-# @dataclass
-# class PastKV:
-#     past_key_values: Any = None
-
-# past_key_values = PastKV()
-
-
 def past_kv_to_device(past_kv, device, dtype):
     return tuple((torch.tensor(k).to(device).to(dtype), torch.tensor(v).to(device).to(dtype)) for k, v in past_kv)
 
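Note: the setup hunk drops the dead CPU/device experiments and the unused start_messages/PastKV scaffolding but keeps past_kv_to_device, which moves a legacy tuple-of-(key, value) cache onto a given device and dtype. A toy round trip with made-up shapes (the real cache comes from the 7B model, so these sizes are illustrative only):

import torch

def past_kv_to_device(past_kv, device, dtype):
    # torch.tensor(...) on an existing tensor copies it before the device/dtype moves
    return tuple((torch.tensor(k).to(device).to(dtype), torch.tensor(v).to(device).to(dtype)) for k, v in past_kv)

toy_cache = tuple(
    (torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8))  # one (key, value) pair per layer
    for _ in range(2)
)
moved = past_kv_to_device(toy_cache, "cpu", torch.float16)
print(moved[0][0].dtype, moved[0][0].shape)  # torch.float16 torch.Size([1, 2, 4, 8])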
@@ -104,20 +72,17 @@ def set_past_key_values():
     return detach_past_kv(model(tokenized.to(model.device)).past_key_values)
 
 
-
+@spaces.GPU
 def generate(text, past_key_values):
-    # messages = [
-    #     *start_messages,
-    #     {'role': 'user', 'content': text}
-    # ]
-
     cur_generate_kwargs = deepcopy(generate_kwargs)
 
     if past_key_values:
         past_key_values = past_kv_to_device(past_key_values, pipe.model.device, pipe.model.dtype)
         cur_generate_kwargs.update({'past_key_values': past_key_values})
 
-    response = pipe(
+    response = pipe(
+        prompt_format.format(system_message=system_prompt, prompt=text), **cur_generate_kwargs
+    )[0]['generated_text']
     print(response)
     return response[-1]['content']
 
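Note: the new @spaces.GPU decorator comes from the `spaces` package used on ZeroGPU Spaces; it attaches a GPU only for the duration of each decorated call. The corresponding import is not visible in this diff, so it is presumably elsewhere in the file. A minimal sketch of the pattern, assuming it runs on a Space with that package installed:

import spaces
import torch

@spaces.GPU  # a GPU is allocated only while this function runs
def gpu_check(_prompt: str) -> str:
    return f"cuda available inside the call: {torch.cuda.is_available()}"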
@@ -126,6 +91,8 @@ if __name__ == "__main__":
     with torch.no_grad():
         past_key_values = set_past_key_values()
     pipe.model = pipe.model.cpu()
-    demo = gr.Interface(
-
+    demo = gr.Interface(
+        partial(generate, past_key_values=past_key_values),
+        inputs="textbox", outputs="textbox"
+    )
     demo.launch()
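Note: the precomputed past_key_values are bound into the handler with functools.partial (the import is not shown in this hunk), so the Gradio interface only exposes the single textbox input. A stripped-down sketch of that wiring with a stand-in cache object:

from functools import partial
import gradio as gr

def generate(text, past_key_values):
    # stand-in handler; the real one formats the prompt and calls the pipeline
    return f"got {len(text)} chars (cache preloaded: {past_key_values is not None})"

cached_state = object()  # stand-in for the precomputed system-prompt KV cache
demo = gr.Interface(partial(generate, past_key_values=cached_state),
                    inputs="textbox", outputs="textbox")
# demo.launch()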