prithivMLmods committed
Commit e576635 · verified · 1 Parent(s): 8c1f8ea

Update app.py

Files changed (1):
  1. app.py +8 -15
app.py CHANGED
@@ -119,7 +119,6 @@ def model_inference(input_dict, history, use_rolmocr=False):
     model = rolmocr_model if use_rolmocr else qwen_model
     model_name = "RolmOCR" if use_rolmocr else "Qwen2VL OCR"

-    # Prepare prompt and inputs
     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     all_images = [item["image"] for item in content if item["type"] == "image"]
     inputs = processor(
@@ -129,7 +128,6 @@ def model_inference(input_dict, history, use_rolmocr=False):
         padding=True,
     ).to("cuda")

-    # Set up streaming
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
@@ -138,23 +136,17 @@ def model_inference(input_dict, history, use_rolmocr=False):
     buffer = ""
     yield progress_bar_html(f"Processing with {model_name}")

-    # Stream tokens
+    # Stream generation
     for new_text in streamer:
-        buffer += new_text
-        buffer = buffer.replace("<|im_end|>", "")
+        buffer += new_text.replace("<|im_end|>", "")
         time.sleep(0.01)
         yield buffer

-    # Once streaming is done, save to response.txt and yield final result
-    results = buffer.strip()
-    try:
-        with open("response.txt", "w", encoding="utf-8") as f:
-            f.write(results)
-    except Exception as e:
-        yield f"Error writing to response.txt: {e}"
-        return
+    # Once complete, save to response.txt and yield final confirmation
+    with open("response.txt", "w", encoding="utf-8") as f:
+        f.write(buffer)

-    yield results
+    yield f"\n✅ Response saved to `response.txt`:\n\n{buffer}"
     return

 # Gradio Interface
@@ -180,4 +172,5 @@ demo = gr.ChatInterface(
     additional_inputs=[gr.Checkbox(label="Use RolmOCR", value=False, info="Check to use RolmOCR, uncheck to use Qwen2VL OCR")],
 )

-demo.launch(debug=True)
+if __name__ == "__main__":
+    demo.launch(debug=True)
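For context on the change: `model_inference` streams output with transformers' `TextIteratorStreamer`, running the blocking `generate()` call in a worker thread while the generator yields decoded text as it arrives. A minimal sketch of that pattern, assuming `model`, `processor`, and `inputs` are the already-loaded model/processor pair and prepared batch from app.py (`stream_tokens` is a hypothetical helper name, not a function in the file):

```python
# Minimal sketch of the threaded streaming loop in model_inference.
# `model`, `processor`, and `inputs` are assumed to be set up as in app.py.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_tokens(model, processor, inputs, max_new_tokens=1024):
    # skip_prompt drops the echoed prompt; skip_special_tokens strips EOS-style tokens
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    # generate() blocks until completion, so it runs in a background thread
    # while this generator consumes decoded chunks as they become available
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer  # yield the growing transcript for live UI updates
```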
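One behavioral nuance of the simplified strip: the old loop re-ran `replace` over the whole accumulated buffer every iteration, which would also catch a marker split across two decoded chunks; the new per-chunk `replace` is cheaper but misses that corner case. A tiny illustration with a hypothetical chunk boundary:

```python
# Hypothetical chunk boundary splitting the end-of-turn marker.
chunks = ["Hello<|im", "_end|>"]

per_chunk = "".join(c.replace("<|im_end|>", "") for c in chunks)
print(per_chunk)     # Hello<|im_end|> -- the split marker survives

whole_buffer = "".join(chunks).replace("<|im_end|>", "")
print(whole_buffer)  # Hello -- rescanning the full buffer catches it
```

In practice this is unlikely to matter here: with `skip_special_tokens=True`, `<|im_end|>` should already be dropped during decoding when it is registered as a special token, so the `replace` acts as a defensive fallback either way.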
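The new `if __name__ == "__main__":` guard keeps the Gradio server from launching when app.py is merely imported (e.g. by tests or tooling) rather than run as a script. A minimal sketch of how the guard fits with the `gr.ChatInterface` wiring shown in the last hunk (the `multimodal=True` flag is an assumption; the full argument list lives earlier in app.py):

```python
import gradio as gr

# Sketch of the interface wiring; model_inference is the generator from the diff.
demo = gr.ChatInterface(
    fn=model_inference,
    multimodal=True,  # assumption: the app accepts image + text chat input
    additional_inputs=[
        gr.Checkbox(
            label="Use RolmOCR",
            value=False,
            info="Check to use RolmOCR, uncheck to use Qwen2VL OCR",
        )
    ],
)

if __name__ == "__main__":
    # debug=True surfaces tracebacks in the Space logs
    demo.launch(debug=True)
```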