hendrydong committed
Commit f075841 · 1 Parent(s): 67d7a5a

Update app.py

Files changed (1):
  app.py +22 -18
app.py CHANGED
@@ -5,13 +5,16 @@
 """
 import logging
 import json
+import os
 import sys
+sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
+import torch
 import warnings
 import gradio as gr
 from dataclasses import dataclass, field
 from transformers import HfArgumentParser
 from typing import Optional
-import torch
+
 from lmflow.datasets.dataset import Dataset
 from lmflow.pipeline.auto_pipeline import AutoPipeline
 from lmflow.models.auto_model import AutoModel
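Note: the new sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0]))) line drops the script's own directory from Python's module search path, so that `import lmflow` resolves to the installed package rather than a same-named directory sitting next to app.py. A minimal sketch of the same idea (the shadowing local lmflow/ directory is an assumption, not shown in this diff):

    import os
    import sys

    # Python normally puts the script's directory first on sys.path, so a
    # local package (e.g. a checked-out lmflow/ folder) would shadow the
    # installed one.
    script_dir = os.path.abspath(os.path.dirname(sys.argv[0]))
    if script_dir in sys.path:  # guard: a bare remove() raises ValueError if absent
        sys.path.remove(script_dir)

    import lmflow  # now resolved from site-packages, not the local tree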
@@ -70,13 +73,13 @@ css = """
 @dataclass
 class ChatbotArguments:
     prompt_structure: Optional[str] = field(
-        default="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: {input_text}###Assistant:",
+        default="{input_text}",
         metadata={
             "help": "prompt structure given user's input text"
         },
     )
     end_string: Optional[str] = field(
-        default="#",
+        default="\n\n",
         metadata={
             "help": "end string mark of the chatbot's output"
         },
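Note: the defaults now pass the user's text through verbatim and stop at a blank line, replacing the Vicuna-style ###Human:/###Assistant: template and the single-# end mark. A minimal sketch of how these two fields plausibly interact (the split-based truncation is an assumption about what stream_inference does with end_string):

    prompt_structure = "{input_text}"
    end_string = "\n\n"

    # Build the prompt for one turn.
    context = prompt_structure.format(input_text="What is LMFlow?")

    # One plausible way an end marker is applied to raw model output:
    raw_output = "LMFlow is a toolkit for finetuning LLMs.\n\nWhat is ..."
    reply = raw_output.split(end_string)[0]
    print(reply)  # "LMFlow is a toolkit for finetuning LLMs."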
@@ -94,7 +97,6 @@ class ChatbotArguments:
         },
     )
 
-
 def main():
     pipeline_name = "inferencer"
     PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
@@ -111,6 +113,7 @@ def main():
     pipeline_args.deepspeed = "configs/ds_config_chatbot.json"
     model_args.torch_dtype = "float16"
 
+
     with open (pipeline_args.deepspeed, "r") as f:
         ds_config = json.load(f)
 
@@ -119,6 +122,7 @@ def main():
         tune_strategy='none',
         ds_config=ds_config,
         device=pipeline_args.device,
+        torch_dtype=torch.float16
     )
 
     # We don't need input data, we will read interactively from stdin
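Note: passing torch_dtype=torch.float16 into the model constructor loads the weights in half precision, roughly halving GPU memory versus float32 (consistent with model_args.torch_dtype = "float16" above). A back-of-the-envelope check in plain PyTorch:

    import torch

    n_params = 7_000_000_000  # e.g. a 7B model, for illustration
    gib = 2 ** 30
    bytes_fp32 = n_params * torch.finfo(torch.float32).bits // 8
    bytes_fp16 = n_params * torch.finfo(torch.float16).bits // 8
    print(f"{bytes_fp32 / gib:.1f} GiB fp32 vs {bytes_fp16 / gib:.1f} GiB fp16")
    # ~26.1 GiB vs ~13.0 GiB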
@@ -150,19 +154,28 @@ def main():
 
     token_per_step = 4
 
+    def hist2context(hist):
+        context = ""
+        for query, response in hist:
+            context += prompt_structure.format(input_text=query)
+            if not (response is None):
+                context += response
+        return context
 
-    def chat_stream( context, query: str, history= None, **kwargs):
+    def chat_stream(query: str, history= None, **kwargs):
         if history is None:
             history = []
 
+        context = hist2context(history)
         print_index = 0
         context += prompt_structure.format(input_text=query)
-        context = context[-model.get_max_length():]
+        context_ = context[-model.get_max_length():]
         input_dataset = dataset.from_dict({
             "type": "text_only",
-            "instances": [ { "text": context } ]
+            "instances": [ { "text": context_ } ]
         })
-        for response, flag_break in inferencer.stream_inference(context=context, model=model, max_new_tokens=chatbot_args.max_new_tokens,
+        print(context_)
+        for response, flag_break in inferencer.stream_inference(context=context_, model=model, max_new_tokens=chatbot_args.max_new_tokens,
                 token_per_step=token_per_step, temperature=chatbot_args.temperature,
                 end_string=end_string, input_dataset=input_dataset):
             delta = response[print_index:]
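Note: chat_stream no longer threads a mutable context argument through calls; each turn rebuilds the context from the (query, response) history and keeps only the last model.get_max_length() characters. Since get_max_length() is presumably a token limit while the slice is character-based, this is a coarse sliding window rather than exact truncation. A standalone sketch of the rebuild-and-truncate step (max_length stands in for model.get_max_length()):

    prompt_structure = "{input_text}"

    def hist2context(hist):
        # Replay every past (query, response) pair into one flat string.
        context = ""
        for query, response in hist:
            context += prompt_structure.format(input_text=query)
            if response is not None:
                context += response
        return context

    history = [("Hi.", "Hello!"), ("Name a prime.", "2")]
    max_length = 16  # stand-in for model.get_max_length()
    context = hist2context(history) + prompt_structure.format(input_text="Why?")
    context_ = context[-max_length:]  # keep only the most recent characters
    print(context_)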
@@ -171,22 +184,15 @@ def main():
 
             yield delta, history + [(query, seq)]
             if flag_break:
-                context += response + "\n"
                 break
 
 
 
 
     def predict(input, history=None):
-        try:
-            global context
-            context = ""
-        except SyntaxError:
-            pass
-
         if history is None:
             history = []
-        for response, history in chat_stream(context, input, history):
+        for response, history in chat_stream(input, history):
             updates = []
             for query, response in history:
                 updates.append(gr.update(visible=True, value="" + query))
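Note: with the global context removed, predict is a plain consumer of the chat_stream generator: each yield carries the newest text delta plus the updated history, which is re-rendered into the chat boxes. A minimal sketch of that producer/consumer shape (Gradio specifics omitted; the token list stands in for stream_inference):

    def chat_stream(query, history):
        partial = ""
        for token in ["Hello", ", ", "world"]:  # stand-in for streamed output
            partial += token
            yield token, history + [(query, partial)]

    def predict(user_input, history=None):
        history = history or []
        for _delta, history in chat_stream(user_input, history):
            print(history[-1][1])  # the app emits gr.update(...) per turn

    predict("greet me")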
@@ -201,7 +207,6 @@ def main():
 
     with gr.Blocks(css=css) as demo:
         gr.HTML(title)
-        gr.HTML('''<center><a href="https://huggingface.co/spaces/OptimalScale/Robin-7b?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg" alt="Duplicate Space"></a></center>''')
         state = gr.State([])
         text_boxes = []
         for i in range(MAX_BOXES):
@@ -221,6 +226,5 @@
 
 
 
-
 if __name__ == "__main__":
     main()