prithivMLmods commited on
Commit
c97526c
·
verified ·
1 Parent(s): 5cdeb4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -280
app.py CHANGED
@@ -20,9 +20,9 @@ from transformers import (
20
  AutoModelForVision2Seq,
21
  AutoProcessor,
22
  TextIteratorStreamer,
23
- EncoderDecoderCache # Added to handle the new caching mechanism
24
  )
25
  from transformers.image_utils import load_image
 
26
 
27
  from docling_core.types.doc import DoclingDocument, DocTagsDocument
28
 
@@ -80,151 +80,126 @@ model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
80
  ).to(device).eval()
81
 
82
  # Preprocessing functions for SmolDocling-256M
 
83
  def add_random_padding(image, min_percent=0.1, max_percent=0.10):
84
  """Add random padding to an image based on its size."""
85
  image = image.convert("RGB")
86
  width, height = image.size
87
- pad_w_percent = random.uniform(min_percent, max_percent)
88
- pad_h_percent = random.uniform(min_percent, max_percent)
89
- pad_w = int(width * pad_w_percent)
90
- pad_h = int(height * pad_h_percent)
91
- corner_pixel = image.getpixel((0, 0)) # Top-left corner
92
- padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
93
- return padded_image
94
 
95
  def normalize_values(text, target_max=500):
96
- """Normalize numerical values in text to a target maximum."""
97
- def normalize_list(values):
98
- max_value = max(values) if values else 1
99
- return [round((v / max_value) * target_max) for v in values]
100
 
101
- def process_match(match):
102
- num_list = ast.literal_eval(match.group(0))
103
- normalized = normalize_list(num_list)
104
- return "".join([f"<loc_{num}>" for num in normalized])
 
105
 
106
- pattern = r"\[([\d\.\s,]+)\]"
107
- normalized_text = re.sub(pattern, process_match, text)
108
- return normalized_text
109
 
110
  def downsample_video(video_path):
111
- """Downsample a video to evenly spaced frames, returning PIL images with timestamps."""
112
- vidcap = cv2.VideoCapture(video_path)
113
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
114
- fps = vidcap.get(cv2.CAP_PROP_FPS)
115
- frames = []
116
- frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
117
- for i in frame_indices:
118
- vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
119
- success, image = vidcap.read()
120
- if success:
121
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
122
- pil_image = Image.fromarray(image)
123
- timestamp = round(i / fps, 2)
124
- frames.append((pil_image, timestamp))
125
- vidcap.release()
126
  return frames
127
 
128
- # Dolphin-specific functions
 
129
  def model_chat(prompt, image):
130
- """Use Dolphin model for inference."""
131
- processor = processor_k
132
- model = model_k
133
- device = "cuda" if torch.cuda.is_available() else "cpu"
134
- inputs = processor(image, return_tensors="pt").to(device)
135
- pixel_values = inputs.pixel_values.half()
136
- prompt_inputs = processor.tokenizer(
137
- f"<s>{prompt} <Answer/>",
138
- add_special_tokens=False,
139
- return_tensors="pt"
140
- ).to(device)
141
-
142
- # Explicitly set past_key_values to None to align with new caching mechanism and avoid deprecated tuple warning
143
- outputs = model.generate(
144
- pixel_values=pixel_values,
145
- decoder_input_ids=prompt_inputs.input_ids,
146
- decoder_attention_mask=prompt_inputs.attention_mask,
147
- min_length=1,
148
- max_length=4096,
149
- pad_token_id=processor.tokenizer.pad_token_id,
150
- eos_token_id=processor.tokenizer.eos_token_id,
151
- use_cache=True,
152
- bad_words_ids=[[processor.tokenizer.unk_token_id]],
 
153
  return_dict_in_generate=True,
154
- do_sample=False,
155
- num_beams=1,
156
- repetition_penalty=1.1,
157
- past_key_values=None # Added to prevent deprecated tuple handling
158
  )
159
- sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
160
- cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
161
- return cleaned
162
-
163
- def process_elements(layout_results, image):
164
- """Parse layout results and extract elements from the image."""
165
- # Placeholder parsing logic based on expected Dolphin output
166
- # Assuming layout_results is a string like "[(x1,y1,x2,y2,label), ...]"
167
  try:
168
- elements = ast.literal_eval(layout_results)
169
  except:
170
- elements = [] # Fallback if parsing fails
171
-
172
- recognition_results = []
173
- reading_order = 0
174
-
175
  for bbox, label in elements:
176
- try:
177
- x1, y1, x2, y2 = map(int, bbox)
178
- cropped = image.crop((x1, y1, x2, y2))
179
- if cropped.size[0] > 0 and cropped.size[1] > 0:
180
- if label == "text":
181
- text = model_chat("Read text in the image.", cropped)
182
- recognition_results.append({
183
- "label": label,
184
- "bbox": [x1, y1, x2, y2],
185
- "text": text.strip(),
186
- "reading_order": reading_order
187
- })
188
- elif label == "table":
189
- table_text = model_chat("Parse the table in the image.", cropped)
190
- recognition_results.append({
191
- "label": label,
192
- "bbox": [x1, y1, x2, y2],
193
- "text": table_text.strip(),
194
- "reading_order": reading_order
195
- })
196
- elif label == "figure":
197
- recognition_results.append({
198
- "label": label,
199
- "bbox": [x1, y1, x2, y2],
200
- "text": "[Figure]", # Placeholder for figure content
201
- "reading_order": reading_order
202
- })
203
- reading_order += 1
204
- except Exception as e:
205
- print(f"Error processing element: {e}")
206
  continue
207
-
208
- return recognition_results
209
-
210
- def generate_markdown(recognition_results):
211
- """Generate markdown from extracted elements."""
212
- markdown = ""
213
- for element in sorted(recognition_results, key=lambda x: x["reading_order"]):
214
- if element["label"] == "text":
215
- markdown += f"{element['text']}\n\n"
216
- elif element["label"] == "table":
217
- markdown += f"**Table:**\n{element['text']}\n\n"
218
- elif element["label"] == "figure":
219
- markdown += f"{element['text']}\n\n"
220
- return markdown.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  def process_image_with_dolphin(image):
223
- """Process a single image with Dolphin model."""
224
- layout_output = model_chat("Parse the reading order of this document.", image)
225
- elements = process_elements(layout_output, image)
226
- markdown_content = generate_markdown(elements)
227
- return markdown_content
228
 
229
  @spaces.GPU
230
  def generate_image(model_name: str, text: str, image: Image.Image,
@@ -233,83 +208,78 @@ def generate_image(model_name: str, text: str, image: Image.Image,
233
  top_p: float = 0.9,
234
  top_k: int = 50,
235
  repetition_penalty: float = 1.2):
236
- """Generate responses for image input using the selected model."""
237
  if model_name == "ByteDance-s-Dolphin":
238
  if image is None:
239
  yield "Please upload an image."
240
- return
241
- markdown_content = process_image_with_dolphin(image)
242
- yield markdown_content
243
- else:
244
- # Existing logic for other models
245
- if model_name == "Nanonets-OCR-s":
246
- processor = processor_m
247
- model = model_m
248
- elif model_name == "MonkeyOCR-Recognition":
249
- processor = processor_g
250
- model = model_g
251
- elif model_name == "SmolDocling-256M-preview":
252
- processor = processor_x
253
- model = model_x
254
  else:
255
- yield "Invalid model selected."
256
- return
257
-
258
- if image is None:
259
- yield "Please upload an image."
260
- return
261
-
262
- images = [image]
263
-
264
- if model_name == "SmolDocling-256M-preview":
265
- if "OTSL" in text or "code" in text:
266
- images = [add_random_padding(img) for img in images]
267
- if "OCR at text at" in text or "Identify element" in text or "formula" in text:
268
- text = normalize_values(text, target_max=500)
269
-
270
- messages = [
271
- {
272
- "role": "user",
273
- "content": [{"type": "image"} for _ in images] + [
274
- {"type": "text", "text": text}
275
- ]
276
- }
277
- ]
278
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
279
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
280
-
281
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
282
- generation_kwargs = {
283
- **inputs,
284
- "streamer": streamer,
285
- "max_new_tokens": max_new_tokens,
286
- "temperature": temperature,
287
- "top_p": top_p,
288
- "top_k": top_k,
289
- "repetition_penalty": repetition_penalty,
290
  }
291
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
292
- thread.start()
293
-
294
- buffer = ""
295
- full_output = ""
296
- for new_text in streamer:
297
- full_output += new_text
298
- buffer += new_text.replace("<|im_end|>", "")
299
- yield buffer
300
-
301
- if model_name == "SmolDocling-256M-preview":
302
- cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
303
- if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
304
- if "<chart>" in cleaned_output:
305
- cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
306
- cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
307
- doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
308
- doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
309
- markdown_output = doc.export_to_markdown()
310
- yield f"**MD Output:**\n\n{markdown_output}"
311
- else:
312
- yield cleaned_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  @spaces.GPU
315
  def generate_video(model_name: str, text: str, video_path: str,
@@ -318,97 +288,77 @@ def generate_video(model_name: str, text: str, video_path: str,
318
  top_p: float = 0.9,
319
  top_k: int = 50,
320
  repetition_penalty: float = 1.2):
321
- """Generate responses for video input using the selected model."""
322
  if model_name == "ByteDance-s-Dolphin":
323
- if video_path is None:
324
  yield "Please upload a video."
325
  return
326
- frames = downsample_video(video_path)
327
- markdown_contents = []
328
- for frame, _ in frames:
329
- markdown_content = process_image_with_dolphin(frame)
330
- markdown_contents.append(markdown_content)
331
- combined_markdown = "\n\n".join(markdown_contents)
332
- yield combined_markdown
 
 
 
 
 
333
  else:
334
- # Existing logic for other models
335
- if model_name == "Nanonets-OCR-s":
336
- processor = processor_m
337
- model = model_m
338
- elif model_name == "MonkeyOCR-Recognition":
339
- processor = processor_g
340
- model = model_g
341
- elif model_name == "SmolDocling-256M-preview":
342
- processor = processor_x
343
- model = model_x
344
- else:
345
- yield "Invalid model selected."
346
- return
347
-
348
- if video_path is None:
349
- yield "Please upload a video."
350
- return
351
-
352
- frames = downsample_video(video_path)
353
- images = [frame for frame, _ in frames]
354
-
355
- if model_name == "SmolDocling-256M-preview":
356
- if "OTSL" in text or "code" in text:
357
- images = [add_random_padding(img) for img in images]
358
- if "OCR at text at" in text or "Identify element" in text or "formula" in text:
359
- text = normalize_values(text, target_max=500)
360
-
361
- messages = [
362
- {
363
- "role": "user",
364
- "content": [{"type": "image"} for _ in images] + [
365
- {"type": "text", "text": text}
366
- ]
367
- }
368
- ]
369
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
370
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
371
-
372
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
373
- generation_kwargs = {
374
- **inputs,
375
- "streamer": streamer,
376
- "max_new_tokens": max_new_tokens,
377
- "temperature": temperature,
378
- "top_p": top_p,
379
- "top_k": top_k,
380
- "repetition_penalty": repetition_penalty,
381
  }
382
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
383
- thread.start()
384
-
385
- buffer = ""
386
- full_output = ""
387
- for new_text in streamer:
388
- full_output += new_text
389
- buffer += new_text.replace("<|im_end|>", "")
390
- yield buffer
391
-
392
- if model_name == "SmolDocling-256M-preview":
393
- cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
394
- if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
395
- if "<chart>" in cleaned_output:
396
- cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
397
- cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
398
- doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
399
- doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
400
- markdown_output = doc.export_to_markdown()
401
- yield f"**MD Output:**\n\n{markdown_output}"
402
- else:
403
- yield cleaned_output
404
-
405
- # Define examples for image and video inference
 
 
 
 
 
 
406
  image_examples = [
407
  ["Convert this page to docling", "images/1.png"],
408
  ["OCR the image", "images/2.jpg"],
409
  ["Convert this page to docling", "images/3.png"],
410
  ]
411
-
412
  video_examples = [
413
  ["Explain the ad in detail", "example/1.mp4"],
414
  ["Identify the main actions in the coca cola ad...", "example/2.mp4"]
@@ -424,7 +374,6 @@ css = """
424
  }
425
  """
426
 
427
- # Create the Gradio Interface
428
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
429
  gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
430
  with gr.Row():
@@ -459,7 +408,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
459
  label="Select Model",
460
  value="Nanonets-OCR-s"
461
  )
462
-
463
  image_submit.click(
464
  fn=generate_image,
465
  inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
 
20
  AutoModelForVision2Seq,
21
  AutoProcessor,
22
  TextIteratorStreamer,
 
23
  )
24
  from transformers.image_utils import load_image
25
+ from transformers.generation import GenerationConfig
26
 
27
  from docling_core.types.doc import DoclingDocument, DocTagsDocument
28
 
 
80
  ).to(device).eval()
81
 
82
  # Preprocessing functions for SmolDocling-256M
83
+
84
  def add_random_padding(image, min_percent=0.1, max_percent=0.10):
85
  """Add random padding to an image based on its size."""
86
  image = image.convert("RGB")
87
  width, height = image.size
88
+ pad_w = int(width * random.uniform(min_percent, max_percent))
89
+ pad_h = int(height * random.uniform(min_percent, max_percent))
90
+ corner_pixel = image.getpixel((0, 0))
91
+ padded = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
92
+ return padded
93
+
 
94
 
95
  def normalize_values(text, target_max=500):
96
+ """Normalize numerical lists in text to a target maximum."""
97
+ def norm_list(vals):
98
+ m = max(vals) if vals else 1
99
+ return [round(v / m * target_max) for v in vals]
100
 
101
+ def repl(m):
102
+ lst = ast.literal_eval(m.group(0))
103
+ return "".join(f"<loc_{n}>" for n in norm_list(lst))
104
+
105
+ return re.sub(r"\[([\d\.\s,]+)\]", repl, text)
106
 
 
 
 
107
 
108
  def downsample_video(video_path):
109
+ """Extract 10 evenly spaced frames (with timestamps) from a video."""
110
+ cap = cv2.VideoCapture(video_path)
111
+ total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
112
+ fps = cap.get(cv2.CAP_PROP_FPS)
113
+ frames, indices = [], np.linspace(0, total - 1, 10, dtype=int)
114
+ for idx in indices:
115
+ cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
116
+ ok, img = cap.read()
117
+ if not ok:
118
+ continue
119
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
120
+ frames.append((Image.fromarray(img), round(idx / fps, 2)))
121
+ cap.release()
 
 
122
  return frames
123
 
124
+ # Dolphin-specific inference
125
+
126
  def model_chat(prompt, image):
127
+ proc = processor_k
128
+ mdl = model_k
129
+ device_str = "cuda" if torch.cuda.is_available() else "cpu"
130
+
131
+ # encode image
132
+ inputs = proc(image, return_tensors="pt").to(device_str).pixel_values.half()
133
+ # encode prompt
134
+ pi = proc.tokenizer(f"<s>{prompt} <Answer/>", add_special_tokens=False, return_tensors="pt").to(device_str)
135
+
136
+ # build generation config
137
+ gen_cfg = GenerationConfig.from_model_config(mdl.config)
138
+ gen_cfg.max_length = 4096
139
+ gen_cfg.min_length = 1
140
+ gen_cfg.use_cache = True
141
+ gen_cfg.bad_words_ids = [[proc.tokenizer.unk_token_id]]
142
+ gen_cfg.num_beams = 1
143
+ gen_cfg.do_sample = False
144
+ gen_cfg.repetition_penalty = 1.1
145
+
146
+ out = mdl.generate(
147
+ pixel_values=inputs,
148
+ decoder_input_ids=pi.input_ids,
149
+ decoder_attention_mask=pi.attention_mask,
150
+ generation_config=gen_cfg,
151
  return_dict_in_generate=True,
 
 
 
 
152
  )
153
+ seq = proc.tokenizer.batch_decode(out.sequences, skip_special_tokens=False)[0]
154
+ return seq.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
155
+
156
+
157
+ def process_elements(layout_result, image):
 
 
 
158
  try:
159
+ elements = ast.literal_eval(layout_result)
160
  except:
161
+ elements = []
162
+
163
+ results, order = [], 0
 
 
164
  for bbox, label in elements:
165
+ x1, y1, x2, y2 = map(int, bbox)
166
+ crop = image.crop((x1, y1, x2, y2))
167
+ if crop.width == 0 or crop.height == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  continue
169
+
170
+ if label == "text":
171
+ txt = model_chat("Read text in the image.", crop)
172
+ elif label == "table":
173
+ txt = model_chat("Parse the table in the image.", crop)
174
+ else:
175
+ txt = "[Figure]"
176
+
177
+ results.append({
178
+ "label": label,
179
+ "bbox": [x1, y1, x2, y2],
180
+ "text": txt.strip(),
181
+ "reading_order": order
182
+ })
183
+ order += 1
184
+
185
+ return results
186
+
187
+
188
+ def generate_markdown(recog):
189
+ md = ""
190
+ for el in sorted(recog, key=lambda x: x["reading_order"]):
191
+ if el["label"] == "text":
192
+ md += el["text"] + "\n\n"
193
+ elif el["label"] == "table":
194
+ md += f"**Table:**\n{el['text']}\n\n"
195
+ else:
196
+ md += el["text"] + "\n\n"
197
+ return md.strip()
198
 
199
  def process_image_with_dolphin(image):
200
+ layout = model_chat("Parse the reading order of this document.", image)
201
+ elems = process_elements(layout, image)
202
+ return generate_markdown(elems)
 
 
203
 
204
  @spaces.GPU
205
  def generate_image(model_name: str, text: str, image: Image.Image,
 
208
  top_p: float = 0.9,
209
  top_k: int = 50,
210
  repetition_penalty: float = 1.2):
 
211
  if model_name == "ByteDance-s-Dolphin":
212
  if image is None:
213
  yield "Please upload an image."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  else:
215
+ yield process_image_with_dolphin(image)
216
+ return
217
+
218
+ if model_name == "Nanonets-OCR-s":
219
+ proc, mdl = processor_m, model_m
220
+ elif model_name == "SmolDocling-256M-preview":
221
+ proc, mdl = processor_x, model_x
222
+ elif model_name == "MonkeyOCR-Recognition":
223
+ proc, mdl = processor_g, model_g
224
+ else:
225
+ yield "Invalid model selected."
226
+ return
227
+
228
+ if image is None:
229
+ yield "Please upload an image."
230
+ return
231
+
232
+ imgs = [image]
233
+ if model_name == "SmolDocling-256M-preview":
234
+ if any(tok in text for tok in ["OTSL", "code"]):
235
+ imgs = [add_random_padding(img) for img in imgs]
236
+ if any(tok in text for tok in ["OCR at text", "Identify element", "formula"]):
237
+ text = normalize_values(text, target_max=500)
238
+
239
+ messages = [
240
+ {"role":"user",
241
+ "content":[{"type":"image"} for _ in imgs] + [{"type":"text","text":text}]
 
 
 
 
 
 
 
 
242
  }
243
+ ]
244
+ prompt = proc.apply_chat_template(messages, add_generation_prompt=True)
245
+ inputs = proc(text=prompt, images=imgs, return_tensors="pt").to(device)
246
+
247
+ gen_cfg = GenerationConfig.from_model_config(mdl.config)
248
+ gen_cfg.max_new_tokens = max_new_tokens
249
+ gen_cfg.temperature = temperature
250
+ gen_cfg.top_p = top_p
251
+ gen_cfg.top_k = top_k
252
+ gen_cfg.repetition_penalty = repetition_penalty
253
+ gen_cfg.use_cache = True
254
+
255
+ streamer = TextIteratorStreamer(proc, skip_prompt=True, skip_special_tokens=True)
256
+ gen_kwargs = {
257
+ **inputs,
258
+ "streamer": streamer,
259
+ "generation_config": gen_cfg,
260
+ }
261
+
262
+ thread = Thread(target=mdl.generate, kwargs=gen_kwargs)
263
+ thread.start()
264
+
265
+ buffer = ""
266
+ full_output = ""
267
+ for new_text in streamer:
268
+ full_output += new_text
269
+ buffer += new_text.replace("<|im_end|>", "")
270
+ yield buffer
271
+
272
+ if model_name == "SmolDocling-256M-preview":
273
+ cleaned = full_output.replace("<end_of_utterance>", "").strip()
274
+ if any(tag in cleaned for tag in ["<doctag>","<otsl>","<code>","<chart>","<formula>"]):
275
+ if "<chart>" in cleaned:
276
+ cleaned = cleaned.replace("<chart>","<otsl>").replace("</chart>","</otsl>")
277
+ cleaned = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned)
278
+ tags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned], imgs)
279
+ doc = DoclingDocument.load_from_doctags(tags_doc, document_name="Document")
280
+ yield f"**MD Output:**\n\n{doc.export_to_markdown()}"
281
+ else:
282
+ yield cleaned
283
 
284
  @spaces.GPU
285
  def generate_video(model_name: str, text: str, video_path: str,
 
288
  top_p: float = 0.9,
289
  top_k: int = 50,
290
  repetition_penalty: float = 1.2):
 
291
  if model_name == "ByteDance-s-Dolphin":
292
+ if not video_path:
293
  yield "Please upload a video."
294
  return
295
+ md_list = []
296
+ for frame, _ in downsample_video(video_path):
297
+ md_list.append(process_image_with_dolphin(frame))
298
+ yield "\n\n".join(md_list)
299
+ return
300
+
301
+ if model_name == "Nanonets-OCR-s":
302
+ proc, mdl = processor_m, model_m
303
+ elif model_name == "SmolDocling-256M-preview":
304
+ proc, mdl = processor_x, model_x
305
+ elif model_name == "MonkeyOCR-Recognition":
306
+ proc, mdl = processor_g, model_g
307
  else:
308
+ yield "Invalid model selected."
309
+ return
310
+
311
+ if not video_path:
312
+ yield "Please upload a video."
313
+ return
314
+
315
+ frames = [f for f, _ in downsample_video(video_path)]
316
+ imgs = frames
317
+ if model_name == "SmolDocling-256M-preview":
318
+ if any(tok in text for tok in ["OTSL", "code"]):
319
+ imgs = [add_random_padding(img) for img in imgs]
320
+ if any(tok in text for tok in ["OCR at text", "Identify element", "formula"]):
321
+ pm.text.normalize_values(text, target_max=500)
322
+
323
+ messages = [
324
+ {"role":"user",
325
+ "content":[{"type":"image"} for _ in imgs] + [{"type":"text","text":text}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  }
327
+ ]
328
+ prompt = proc.apply_chat_template(messages, add_generation_prompt=True)
329
+ inputs = proc(text=prompt, images=imgs, return_tensors="pt").to(device)
330
+
331
+ gen_cfg = GenerationConfig.from_model_config(mdl.config)
332
+ gen_cfg.max_new_tokens = max_new_tokens
333
+ gen_cfg.temperature = temperature
334
+ gen_cfg.top_p = top_p
335
+ gen_cfg.top_k = top_k
336
+ gen_cfg.repetition_penalty = repetition_penalty
337
+ gen_cfg.use_cache = True
338
+
339
+ streamer = TextIteratorStreamer(proc, skip_prompt=True, skip_special_tokens=True)
340
+ gen_kwargs = {
341
+ **inputs,
342
+ "streamer": streamer,
343
+ "generation_config": gen_cfg,
344
+ }
345
+
346
+ thread = Thread(target=mdl.generate, kwargs=gen_kwargs)
347
+ thread.start()
348
+
349
+ buff = ""
350
+ full = ""
351
+ for nt in streamer:
352
+ full += nt
353
+ buff += nt.replace("<|im_end|>", "")
354
+ yield buff
355
+
356
+ # Gradio UI
357
  image_examples = [
358
  ["Convert this page to docling", "images/1.png"],
359
  ["OCR the image", "images/2.jpg"],
360
  ["Convert this page to docling", "images/3.png"],
361
  ]
 
362
  video_examples = [
363
  ["Explain the ad in detail", "example/1.mp4"],
364
  ["Identify the main actions in the coca cola ad...", "example/2.mp4"]
 
374
  }
375
  """
376
 
 
377
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
378
  gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
379
  with gr.Row():
 
408
  label="Select Model",
409
  value="Nanonets-OCR-s"
410
  )
 
411
  image_submit.click(
412
  fn=generate_image,
413
  inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],