Commit f075841 · Update app.py
Parent: 67d7a5a

app.py CHANGED
@@ -5,13 +5,16 @@
 """
 import logging
 import json
+import os
 import sys
+sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
+import torch
 import warnings
 import gradio as gr
 from dataclasses import dataclass, field
 from transformers import HfArgumentParser
 from typing import Optional
-
+
 from lmflow.datasets.dataset import Dataset
 from lmflow.pipeline.auto_pipeline import AutoPipeline
 from lmflow.models.auto_model import AutoModel
@@ -70,13 +73,13 @@ css = """
 @dataclass
 class ChatbotArguments:
     prompt_structure: Optional[str] = field(
-        default="
+        default="{input_text}",
         metadata={
             "help": "prompt structure given user's input text"
         },
     )
     end_string: Optional[str] = field(
-        default="
+        default="\n\n",
         metadata={
             "help": "end string mark of the chatbot's output"
         },
@@ -94,7 +97,6 @@ class ChatbotArguments:
         },
     )

-
 def main():
     pipeline_name = "inferencer"
     PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
@@ -111,6 +113,7 @@ def main():
     pipeline_args.deepspeed = "configs/ds_config_chatbot.json"
     model_args.torch_dtype = "float16"

+
     with open (pipeline_args.deepspeed, "r") as f:
         ds_config = json.load(f)

@@ -119,6 +122,7 @@ def main():
         tune_strategy='none',
         ds_config=ds_config,
         device=pipeline_args.device,
+        torch_dtype=torch.float16
     )

     # We don't need input data, we will read interactively from stdin
@@ -150,19 +154,28 @@ def main():

     token_per_step = 4

+    def hist2context(hist):
+        context = ""
+        for query, response in hist:
+            context += prompt_structure.format(input_text=query)
+            if not (response is None):
+                context += response
+        return context

-    def chat_stream(
+    def chat_stream(query: str, history= None, **kwargs):
         if history is None:
             history = []

+        context = hist2context(history)
         print_index = 0
         context += prompt_structure.format(input_text=query)
-
+        context_ = context[-model.get_max_length():]
         input_dataset = dataset.from_dict({
             "type": "text_only",
-            "instances": [ { "text":
+            "instances": [ { "text": context_ } ]
         })
-
+        print(context_)
+        for response, flag_break in inferencer.stream_inference(context=context_, model=model, max_new_tokens=chatbot_args.max_new_tokens,
             token_per_step=token_per_step, temperature=chatbot_args.temperature,
             end_string=end_string, input_dataset=input_dataset):
             delta = response[print_index:]
@@ -171,22 +184,15 @@ def main():

             yield delta, history + [(query, seq)]
             if flag_break:
-                context += response + "\n"
                 break




     def predict(input, history=None):
-        try:
-            global context
-            context = ""
-        except SyntaxError:
-            pass
-
         if history is None:
             history = []
-        for response, history in chat_stream(
+        for response, history in chat_stream(input, history):
             updates = []
             for query, response in history:
                 updates.append(gr.update(visible=True, value="" + query))
@@ -201,7 +207,6 @@ def main():

     with gr.Blocks(css=css) as demo:
         gr.HTML(title)
-        gr.HTML('''<center><a href="https://huggingface.co/spaces/OptimalScale/Robin-7b?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg" alt="Duplicate Space"></a></center>''')
         state = gr.State([])
         text_boxes = []
         for i in range(MAX_BOXES):
@@ -221,6 +226,5 @@



-
 if __name__ == "__main__":
     main()
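
For reference, a minimal standalone sketch of the new hist2context helper this commit introduces, runnable outside the app. It assumes the default prompt_structure "{input_text}" defined in ChatbotArguments above; the sample history is hypothetical.

    # Sketch of the hist2context helper added in this commit.
    # Assumption: prompt_structure uses the "{input_text}" default from
    # ChatbotArguments; the history below is an invented example.
    prompt_structure = "{input_text}"

    def hist2context(hist):
        # Rebuild the running conversation context from (query, response)
        # pairs; a pending query whose response is still None contributes
        # only the query text.
        context = ""
        for query, response in hist:
            context += prompt_structure.format(input_text=query)
            if response is not None:
                context += response
        return context

    history = [("Hello!", "Hi, how can I help?"), ("Tell me a joke.", None)]
    print(hist2context(history))
    # Output: Hello!Hi, how can I help?Tell me a joke.

In chat_stream, this rebuilt context gets the new query appended, is sliced to the trailing model.get_max_length() characters, and is wrapped in a text_only dataset before being passed to inferencer.stream_inference.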