prithivMLmods committed on
Commit e29072f · verified · 1 Parent(s): f396cd1

Update app.py

Files changed (1)
  1. app.py +152 -267
app.py CHANGED
@@ -16,7 +16,6 @@ import cv2
 from transformers import (
     Qwen2VLForConditionalGeneration,
     Qwen2_5_VLForConditionalGeneration,
-    VisionEncoderDecoderModel,
     AutoModelForVision2Seq,
     AutoProcessor,
     TextIteratorStreamer,
@@ -45,15 +44,6 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-# Load ByteDance's Dolphin
-MODEL_ID_K = "ByteDance/Dolphin"
-processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
-model_k = VisionEncoderDecoderModel.from_pretrained(
-    MODEL_ID_K,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to(device).eval()
-
 # Load SmolDocling-256M-preview
 MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
 processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
@@ -124,104 +114,6 @@ def downsample_video(video_path):
     vidcap.release()
     return frames
 
-# Dolphin-specific functions
-def model_chat(prompt, image):
-    """Use Dolphin model for inference."""
-    processor = processor_k
-    model = model_k
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    inputs = processor(image, return_tensors="pt").to(device)
-    pixel_values = inputs.pixel_values.half()
-    prompt_inputs = processor.tokenizer(
-        f"<s>{prompt} <Answer/>",
-        add_special_tokens=False,
-        return_tensors="pt"
-    ).to(device)
-    outputs = model.generate(
-        pixel_values=pixel_values,
-        decoder_input_ids=prompt_inputs.input_ids,
-        decoder_attention_mask=prompt_inputs.attention_mask,
-        min_length=1,
-        max_length=4096,
-        pad_token_id=processor.tokenizer.pad_token_id,
-        eos_token_id=processor.tokenizer.eos_token_id,
-        use_cache=False,  # Changed to False to avoid deprecation warning
-        bad_words_ids=[[processor.tokenizer.unk_token_id]],
-        return_dict_in_generate=True,
-        do_sample=False,
-        num_beams=1,
-        repetition_penalty=1.1
-    )
-    sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
-    cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
-    return cleaned
-
-def process_elements(layout_results, image):
-    """Parse layout results and extract elements from the image."""
-    # Placeholder parsing logic based on expected Dolphin output
-    # Assuming layout_results is a string like "[(x1,y1,x2,y2,label), ...]"
-    try:
-        elements = ast.literal_eval(layout_results)
-    except:
-        elements = []  # Fallback if parsing fails
-
-    recognition_results = []
-    reading_order = 0
-
-    for bbox, label in elements:
-        try:
-            x1, y1, x2, y2 = map(int, bbox)
-            cropped = image.crop((x1, y1, x2, y2))
-            if cropped.size[0] > 0 and cropped.size[1] > 0:
-                if label == "text":
-                    text = model_chat("Read text in the image.", cropped)
-                    recognition_results.append({
-                        "label": label,
-                        "bbox": [x1, y1, x2, y2],
-                        "text": text.strip(),
-                        "reading_order": reading_order
-                    })
-                elif label == "table":
-                    table_text = model_chat("Parse the table in the image.", cropped)
-                    recognition_results.append({
-                        "label": label,
-                        "bbox": [x1, y1, x2, y2],
-                        "text": table_text.strip(),
-                        "reading_order": reading_order
-                    })
-                elif label == "figure":
-                    recognition_results.append({
-                        "label": label,
-                        "bbox": [x1, y1, x2, y2],
-                        "text": "[Figure]",  # Placeholder for figure content
-                        "reading_order": reading_order
-                    })
-                reading_order += 1
-        except Exception as e:
-            print(f"Error processing element: {e}")
-            continue
-
-    return recognition_results
-
-def generate_markdown(recognition_results):
-    """Generate markdown from extracted elements."""
-    markdown = ""
-    for element in sorted(recognition_results, key=lambda x: x["reading_order"]):
-        if element["label"] == "text":
-            markdown += f"{element['text']}\n\n"
-        elif element["label"] == "table":
-            markdown += f"**Table:**\n{element['text']}\n\n"
-        elif element["label"] == "figure":
-            markdown += f"{element['text']}\n\n"
-    return markdown.strip()
-
-def process_image_with_dolphin(image):
-    """Process a single image with Dolphin model."""
-    layout_output = model_chat("Parse the reading order of this document.", image)
-    elements = process_elements(layout_output, image)
-    markdown_content = generate_markdown(elements)
-    return markdown_content
-
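Note: the code removed above implemented Dolphin's two-pass flow: one generate call to get the layout and reading order, then one call per cropped element. A condensed sketch of that pattern, using the same checkpoint and `<s>{prompt} <Answer/>` prompt format as the deleted helper (reference only, not part of the new app):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel

# Condensed from the deleted model_chat() helper; reference sketch only.
processor = AutoProcessor.from_pretrained("ByteDance/Dolphin", trust_remote_code=True)
model = VisionEncoderDecoderModel.from_pretrained("ByteDance/Dolphin", trust_remote_code=True).eval()

def dolphin_chat(prompt: str, image: Image.Image) -> str:
    # Dolphin is an encoder-decoder model: the image feeds the encoder,
    # while the text prompt seeds the decoder directly.
    pixel_values = processor(image, return_tensors="pt").pixel_values
    decoder_inputs = processor.tokenizer(
        f"<s>{prompt} <Answer/>", add_special_tokens=False, return_tensors="pt"
    )
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_inputs.input_ids,
            decoder_attention_mask=decoder_inputs.attention_mask,
            max_length=4096,
        )
    text = processor.tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0]
    return text.replace(f"<s>{prompt} <Answer/>", "").replace("</s>", "").strip()
```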
 @spaces.GPU
 def generate_image(model_name: str, text: str, image: Image.Image,
                    max_new_tokens: int = 1024,
@@ -230,82 +122,81 @@ def generate_image(model_name: str, text: str, image: Image.Image,
                    top_k: int = 50,
                    repetition_penalty: float = 1.2):
     """Generate responses for image input using the selected model."""
-    if model_name == "ByteDance-s-Dolphin":
-        if image is None:
-            yield "Please upload an image."
-            return
-        markdown_content = process_image_with_dolphin(image)
-        yield markdown_content
+    # Model selection
+    if model_name == "Nanonets-OCR-s":
+        processor = processor_m
+        model = model_m
+    elif model_name == "MonkeyOCR-Recognition":
+        processor = processor_g
+        model = model_g
+    elif model_name == "SmolDocling-256M-preview":
+        processor = processor_x
+        model = model_x
     else:
-        # Existing logic for other models
-        if model_name == "Nanonets-OCR-s":
-            processor = processor_m
-            model = model_m
-        elif model_name == "MonkeyOCR-Recognition":
-            processor = processor_g
-            model = model_g
-        elif model_name == "SmolDocling-256M-preview":
-            processor = processor_x
-            model = model_x
-        else:
-            yield "Invalid model selected."
-            return
-
-        if image is None:
-            yield "Please upload an image."
-            return
-
-        images = [image]
-
-        if model_name == "SmolDocling-256M-preview":
-            if "OTSL" in text or "code" in text:
-                images = [add_random_padding(img) for img in images]
-            if "OCR at text at" in text or "Identify element" in text or "formula" in text:
-                text = normalize_values(text, target_max=500)
-
-        messages = [
-            {
-                "role": "user",
-                "content": [{"type": "image"} for _ in images] + [
-                    {"type": "text", "text": text}
-                ]
-            }
-        ]
-        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-        inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
-
-        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            **inputs,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
+        yield "Invalid model selected."
+        return
+
+    if image is None:
+        yield "Please upload an image."
+        return
+
+    # Prepare images as a list (single image for image inference)
+    images = [image]
+
+    # SmolDocling-256M specific preprocessing
+    if model_name == "SmolDocling-256M-preview":
+        if "OTSL" in text or "code" in text:
+            images = [add_random_padding(img) for img in images]
+        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
+            text = normalize_values(text, target_max=500)
+
+    # Unified message structure for all models
+    messages = [
+        {
+            "role": "user",
+            "content": [{"type": "image"} for _ in images] + [
+                {"type": "text", "text": text}
+            ]
         }
-        thread = Thread(target=model.generate, kwargs=generation_kwargs)
-        thread.start()
-
-        buffer = ""
-        full_output = ""
-        for new_text in streamer:
-            full_output += new_text
-            buffer += new_text.replace("<|im_end|>", "")
-            yield buffer
-
-        if model_name == "SmolDocling-256M-preview":
-            cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
-            if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
-                if "<chart>" in cleaned_output:
-                    cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
-                cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
-                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
-                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
-                markdown_output = doc.export_to_markdown()
-                yield f"**MD Output:**\n\n{markdown_output}"
-            else:
-                yield cleaned_output
+    ]
+    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+    # Generation with streaming
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {
+        **inputs,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+    }
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    # Stream output and collect full response
+    buffer = ""
+    full_output = ""
+    for new_text in streamer:
+        full_output += new_text
+        buffer += new_text.replace("<|im_end|>", "")
+        yield buffer
+
+    # SmolDocling-256M specific postprocessing
+    if model_name == "SmolDocling-256M-preview":
+        cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
+        if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
+            if "<chart>" in cleaned_output:
+                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
+            cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
+            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
+            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
+            markdown_output = doc.export_to_markdown()
+            yield f"**MD Output:**\n\n{markdown_output}"
+        else:
+            yield cleaned_output
 
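Both rewritten handlers share the same streaming skeleton: `model.generate` runs on a worker thread while `TextIteratorStreamer` hands decoded chunks back to the Gradio generator, so partial output appears immediately. A minimal sketch of the pattern (the `model`/`processor` pair stands in for whichever selection was made above; the app passes the processor where a tokenizer is expected, relying on it to forward decoding to its tokenizer):

```python
from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, processor, inputs, **gen_kwargs):
    # generate() blocks, so it runs on a worker thread; the streamer
    # yields decoded text chunks here as tokens are produced.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    Thread(target=model.generate,
           kwargs={**inputs, "streamer": streamer, **gen_kwargs}).start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer  # Gradio re-renders the output with each yield
```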
 @spaces.GPU
 def generate_video(model_name: str, text: str, video_path: str,
@@ -315,94 +206,88 @@ def generate_video(model_name: str, text: str, video_path: str,
                    top_k: int = 50,
                    repetition_penalty: float = 1.2):
     """Generate responses for video input using the selected model."""
-    if model_name == "ByteDance-s-Dolphin":
-        if video_path is None:
-            yield "Please upload a video."
-            return
-        frames = downsample_video(video_path)
-        markdown_contents = []
-        for frame, _ in frames:
-            markdown_content = process_image_with_dolphin(frame)
-            markdown_contents.append(markdown_content)
-        combined_markdown = "\n\n".join(markdown_contents)
-        yield combined_markdown
+    # Model selection
+    if model_name == "Nanonets-OCR-s":
+        processor = processor_m
+        model = model_m
+    elif model_name == "MonkeyOCR-Recognition":
+        processor = processor_g
+        model = model_g
+    elif model_name == "SmolDocling-256M-preview":
+        processor = processor_x
+        model = model_x
     else:
-        # Existing logic for other models
-        if model_name == "Nanonets-OCR-s":
-            processor = processor_m
-            model = model_m
-        elif model_name == "MonkeyOCR-Recognition":
-            processor = processor_g
-            model = model_g
-        elif model_name == "SmolDocling-256M-preview":
-            processor = processor_x
-            model = model_x
-        else:
-            yield "Invalid model selected."
-            return
-
-        if video_path is None:
-            yield "Please upload a video."
-            return
-
-        frames = downsample_video(video_path)
-        images = [frame for frame, _ in frames]
-
-        if model_name == "SmolDocling-256M-preview":
-            if "OTSL" in text or "code" in text:
-                images = [add_random_padding(img) for img in images]
-            if "OCR at text at" in text or "Identify element" in text or "formula" in text:
-                text = normalize_values(text, target_max=500)
-
-        messages = [
-            {
-                "role": "user",
-                "content": [{"type": "image"} for _ in images] + [
-                    {"type": "text", "text": text}
-                ]
-            }
-        ]
-        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-        inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
-
-        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            **inputs,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
+        yield "Invalid model selected."
+        return
+
+    if video_path is None:
+        yield "Please upload a video."
+        return
+
+    # Extract frames from video
+    frames = downsample_video(video_path)
+    images = [frame for frame, _ in frames]
+
+    # SmolDocling-256M specific preprocessing
+    if model_name == "SmolDocling-256M-preview":
+        if "OTSL" in text or "code" in text:
+            images = [add_random_padding(img) for img in images]
+        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
+            text = normalize_values(text, target_max=500)
+
+    # Unified message structure for all models
+    messages = [
+        {
+            "role": "user",
+            "content": [{"type": "image"} for _ in images] + [
+                {"type": "text", "text": text}
+            ]
         }
-        thread = Thread(target=model.generate, kwargs=generation_kwargs)
-        thread.start()
-
-        buffer = ""
-        full_output = ""
-        for new_text in streamer:
-            full_output += new_text
-            buffer += new_text.replace("<|im_end|>", "")
-            yield buffer
-
-        if model_name == "SmolDocling-256M-preview":
-            cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
-            if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
-                if "<chart>" in cleaned_output:
-                    cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
-                cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
-                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
-                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
-                markdown_output = doc.export_to_markdown()
-                yield f"**MD Output:**\n\n{markdown_output}"
-            else:
-                yield cleaned_output
+    ]
+    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+    # Generation with streaming
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {
+        **inputs,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+    }
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    # Stream output and collect full response
+    buffer = ""
+    full_output = ""
+    for new_text in streamer:
+        full_output += new_text
+        buffer += new_text.replace("<|im_end|>", "")
+        yield buffer
+
+    # SmolDocling-256M specific postprocessing
+    if model_name == "SmolDocling-256M-preview":
+        cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
+        if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
+            if "<chart>" in cleaned_output:
+                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
+            cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
+            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
+            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
+            markdown_output = doc.export_to_markdown()
+            yield f"**MD Output:**\n\n{markdown_output}"
+        else:
+            yield cleaned_output
 
 # Define examples for image and video inference
 image_examples = [
-    ["Convert this page to docling", "images/1.png"],
-    ["OCR the image", "images/2.jpg"],
-    ["Convert this page to docling", "images/3.png"],
+    ["fill the correct numbers", "example/image3.png"],
+    ["ocr the image", "example/image1.png"],
+    ["explain the scene", "example/image2.jpg"],
 ]
 
 video_examples = [
@@ -422,7 +307,7 @@ css = """
 
 # Create the Gradio Interface
 with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-    gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
+    gr.Markdown("# **[core OCR](https://huggingface.co/collections/prithivMLmods/core-and-docscope-ocr-models-6816d7f1bde3f911c6c852bc)**")
     with gr.Row():
         with gr.Column():
             with gr.Tabs():
@@ -451,7 +336,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
         with gr.Column():
             output = gr.Textbox(label="Output", interactive=False, lines=3, scale=2)
             model_choice = gr.Radio(
-                choices=["Nanonets-OCR-s", "SmolDocling-256M-preview", "MonkeyOCR-Recognition", "ByteDance-s-Dolphin"],
+                choices=["Nanonets-OCR-s", "MonkeyOCR-Recognition", "SmolDocling-256M-preview"],
                 label="Select Model",
                 value="Nanonets-OCR-s"
             )
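In the SmolDocling branch above, DocTags output becomes Markdown through docling-core. A standalone sketch of that conversion, using the same two calls as the diff (the `<doctag>` string here is a fabricated example, and the import paths assume a recent docling-core release):

```python
from PIL import Image
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument

# Fabricated DocTags output plus a blank stand-in page image.
doctags = "<doctag><text><loc_58><loc_44><loc_426><loc_91>Hello, world.</text></doctag>"
page = Image.new("RGB", (512, 512), "white")

# Same two calls the handlers use after cleaning the model output.
doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [page])
doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
print(doc.export_to_markdown())
```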
 
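One helper the diff calls but never shows is `downsample_video`; the context lines in the third hunk only reveal that it ends with `vidcap.release()` and returns frames, and the callers unpack `(frame, timestamp)` pairs of PIL images. A plausible reconstruction under those constraints (the frame count and evenly spaced sampling are assumptions):

```python
import cv2
from PIL import Image

def downsample_video(video_path, num_frames=10):
    # Hypothetical sketch: sample num_frames evenly spaced frames and
    # return (PIL image, timestamp-in-seconds) pairs, as the callers expect.
    vidcap = cv2.VideoCapture(video_path)
    total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0
    frames = []
    for i in range(num_frames):
        idx = int(i * max(total - 1, 0) / max(num_frames - 1, 1))
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ok, bgr = vidcap.read()
        if not ok:
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        frames.append((Image.fromarray(rgb), round(idx / fps, 2)))
    vidcap.release()
    return frames
```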