VictorSanh commited on
Commit
e137273
·
1 Parent(s): 01b5c05

make token streaming work

Browse files
Files changed (1) hide show
  1. app.py +17 -15
app.py CHANGED
@@ -139,9 +139,6 @@ def model_inference(
139
 
140
  streamer = TextIteratorStreamer(
141
  PROCESSOR.tokenizer,
142
- decode_kwargs=dict(
143
- skip_special_tokens=True
144
- ),
145
  skip_prompt=True,
146
  )
147
  generation_kwargs = dict(
@@ -150,6 +147,16 @@ def model_inference(
150
  max_length=4096,
151
  streamer=streamer,
152
  )
 
 
 
 
 
 
 
 
 
 
153
  thread = Thread(
154
  target=MODEL.generate,
155
  kwargs=generation_kwargs,
@@ -157,15 +164,14 @@ def model_inference(
157
  thread.start()
158
  generated_text = ""
159
  for new_text in streamer:
 
 
 
 
 
160
  generated_text += new_text
161
- print("before yield")
162
- # yield generated_text, image
163
- print("after yield")
164
 
165
- # Sanity hack
166
- generated_text = generated_text.replace("</s>", "")
167
- rendered_page = render_webpage(generated_text)
168
- return generated_text, rendered_page
169
 
170
  generated_html = gr.Code(
171
  label="Extracted HTML",
@@ -234,26 +240,22 @@ with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as d
234
  fn=model_inference,
235
  inputs=[imagebox],
236
  outputs=[generated_html, rendered_html],
237
- queue=False,
238
  )
239
  regenerate_btn.click(
240
  fn=model_inference,
241
  inputs=[imagebox],
242
  outputs=[generated_html, rendered_html],
243
- queue=False,
244
  )
245
  template_gallery.select(
246
  fn=add_file_gallery,
247
  inputs=[template_gallery],
248
  outputs=[imagebox],
249
- queue=False,
250
  ).success(
251
  fn=model_inference,
252
  inputs=[imagebox],
253
  outputs=[generated_html, rendered_html],
254
- queue=False,
255
  )
256
- demo.load(queue=False)
257
 
258
  demo.queue(max_size=40, api_open=False)
259
  demo.launch(max_threads=400)
 
139
 
140
  streamer = TextIteratorStreamer(
141
  PROCESSOR.tokenizer,
 
 
 
142
  skip_prompt=True,
143
  )
144
  generation_kwargs = dict(
 
147
  max_length=4096,
148
  streamer=streamer,
149
  )
150
+ # Regular generation version
151
+ # generation_kwargs.pop("streamer")
152
+ # generated_ids = MODEL.generate(**generation_kwargs)
153
+ # generated_text = PROCESSOR.batch_decode(
154
+ # generated_ids,
155
+ # skip_special_tokens=True
156
+ # )[0]
157
+ # rendered_page = render_webpage(generated_text)
158
+ # return generated_text, rendered_page
159
+ # Token streaming version
160
  thread = Thread(
161
  target=MODEL.generate,
162
  kwargs=generation_kwargs,
 
164
  thread.start()
165
  generated_text = ""
166
  for new_text in streamer:
167
+ if "</s>" in new_text:
168
+ new_text = new_text.replace("</s>", "")
169
+ rendered_image = render_webpage(generated_text)
170
+ else:
171
+ rendered_image = None
172
  generated_text += new_text
173
+ yield generated_text, rendered_image
 
 
174
 
 
 
 
 
175
 
176
  generated_html = gr.Code(
177
  label="Extracted HTML",
 
240
  fn=model_inference,
241
  inputs=[imagebox],
242
  outputs=[generated_html, rendered_html],
 
243
  )
244
  regenerate_btn.click(
245
  fn=model_inference,
246
  inputs=[imagebox],
247
  outputs=[generated_html, rendered_html],
 
248
  )
249
  template_gallery.select(
250
  fn=add_file_gallery,
251
  inputs=[template_gallery],
252
  outputs=[imagebox],
 
253
  ).success(
254
  fn=model_inference,
255
  inputs=[imagebox],
256
  outputs=[generated_html, rendered_html],
 
257
  )
258
+ demo.load()
259
 
260
  demo.queue(max_size=40, api_open=False)
261
  demo.launch(max_threads=400)