prithivMLmods committed on
Commit 59ff1ca · verified · 1 Parent(s): db355c7

Update app.py

Files changed (1): app.py (+276, -228)
app.py CHANGED
@@ -22,7 +22,6 @@ from transformers import (
    TextIteratorStreamer,
)
from transformers.image_utils import load_image
-from transformers.generation import GenerationConfig

from docling_core.types.doc import DoclingDocument, DocTagsDocument

@@ -80,126 +79,148 @@ model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
).to(device).eval()

# Preprocessing functions for SmolDocling-256M
-
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Add random padding to an image based on its size."""
    image = image.convert("RGB")
    width, height = image.size
-    pad_w = int(width * random.uniform(min_percent, max_percent))
-    pad_h = int(height * random.uniform(min_percent, max_percent))
-    corner_pixel = image.getpixel((0, 0))
-    padded = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
-    return padded
-
+    pad_w_percent = random.uniform(min_percent, max_percent)
+    pad_h_percent = random.uniform(min_percent, max_percent)
+    pad_w = int(width * pad_w_percent)
+    pad_h = int(height * pad_h_percent)
+    corner_pixel = image.getpixel((0, 0))  # Top-left corner
+    padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
+    return padded_image

def normalize_values(text, target_max=500):
-    """Normalize numerical lists in text to a target maximum."""
-    def norm_list(vals):
-        m = max(vals) if vals else 1
-        return [round(v / m * target_max) for v in vals]
+    """Normalize numerical values in text to a target maximum."""
+    def normalize_list(values):
+        max_value = max(values) if values else 1
+        return [round((v / max_value) * target_max) for v in values]

-    def repl(m):
-        lst = ast.literal_eval(m.group(0))
-        return "".join(f"<loc_{n}>" for n in norm_list(lst))
-
-    return re.sub(r"\[([\d\.\s,]+)\]", repl, text)
+    def process_match(match):
+        num_list = ast.literal_eval(match.group(0))
+        normalized = normalize_list(num_list)
+        return "".join([f"<loc_{num}>" for num in normalized])
+
+    pattern = r"\[([\d\.\s,]+)\]"
+    normalized_text = re.sub(pattern, process_match, text)
+    return normalized_text
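As a sanity check on the rewritten normalize_values (an illustrative call, not part of the commit): every bracketed number list in the prompt is rescaled so its maximum becomes target_max, then rewritten as <loc_*> tokens.

# Illustrative only; assumes normalize_values as defined above.
text = "Identify element at [10, 20, 40]"
print(normalize_values(text, target_max=500))
# -> "Identify element at <loc_125><loc_250><loc_500>"
# (each value v becomes round(v / max(values) * 500))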
 
def downsample_video(video_path):
-    """Extract 10 evenly spaced frames (with timestamps) from a video."""
-    cap = cv2.VideoCapture(video_path)
-    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    frames, indices = [], np.linspace(0, total - 1, 10, dtype=int)
-    for idx in indices:
-        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
-        ok, img = cap.read()
-        if not ok:
-            continue
-        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        frames.append((Image.fromarray(img), round(idx / fps, 2)))
-    cap.release()
+    """Downsample a video to evenly spaced frames, returning PIL images with timestamps."""
+    vidcap = cv2.VideoCapture(video_path)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frames = []
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    for i in frame_indices:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+    vidcap.release()
    return frames
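Usage sketch for downsample_video (illustrative; "clip.mp4" is a hypothetical file). The function always samples ten indices with np.linspace, so it returns up to ten (PIL.Image, timestamp) pairs:

# Illustrative only; "clip.mp4" is a hypothetical input.
for img, ts in downsample_video("clip.mp4"):
    print(img.size, ts)
# For a 300-frame clip at 30 fps the sampled indices would be
# [0, 33, 66, 99, 132, 166, 199, 232, 265, 299], i.e. timestamps 0.0 .. 9.97 s.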
 
-# Dolphin-specific inference
-
+# Dolphin-specific functions
def model_chat(prompt, image):
-    proc = processor_k
-    mdl = model_k
-    device_str = "cuda" if torch.cuda.is_available() else "cpu"
-
-    # encode image
-    inputs = proc(image, return_tensors="pt").to(device_str).pixel_values.half()
-    # encode prompt
-    pi = proc.tokenizer(f"<s>{prompt} <Answer/>", add_special_tokens=False, return_tensors="pt").to(device_str)
-
-    # build generation config
-    gen_cfg = GenerationConfig.from_model_config(mdl.config)
-    gen_cfg.max_length = 4096
-    gen_cfg.min_length = 1
-    gen_cfg.use_cache = True
-    gen_cfg.bad_words_ids = [[proc.tokenizer.unk_token_id]]
-    gen_cfg.num_beams = 1
-    gen_cfg.do_sample = False
-    gen_cfg.repetition_penalty = 1.1
-
-    out = mdl.generate(
-        pixel_values=inputs,
-        decoder_input_ids=pi.input_ids,
-        decoder_attention_mask=pi.attention_mask,
-        generation_config=gen_cfg,
+    """Use Dolphin model for inference."""
+    processor = processor_k
+    model = model_k
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    inputs = processor(image, return_tensors="pt").to(device)
+    pixel_values = inputs.pixel_values.half()
+    prompt_inputs = processor.tokenizer(
+        f"<s>{prompt} <Answer/>",
+        add_special_tokens=False,
+        return_tensors="pt"
+    ).to(device)
+    outputs = model.generate(
+        pixel_values=pixel_values,
+        decoder_input_ids=prompt_inputs.input_ids,
+        decoder_attention_mask=prompt_inputs.attention_mask,
+        min_length=1,
+        max_length=4096,
+        pad_token_id=processor.tokenizer.pad_token_id,
+        eos_token_id=processor.tokenizer.eos_token_id,
+        use_cache=False,  # Changed to False to avoid deprecation warning
+        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
+        do_sample=False,
+        num_beams=1,
+        repetition_penalty=1.1
    )
-    seq = proc.tokenizer.batch_decode(out.sequences, skip_special_tokens=False)[0]
-    return seq.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
+    sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
+    cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
+    return cleaned
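A minimal sketch of the decoder-prompt protocol model_chat wraps (illustrative; "page.png" is a hypothetical file). Dolphin is queried once for layout and then once per cropped element, always with the literal "<s>{prompt} <Answer/>" prefix, which is stripped from the decoded sequence along with <pad> and </s>:

# Illustrative only; requires model_k/processor_k to be loaded as above.
from PIL import Image
page = Image.open("page.png")
layout = model_chat("Parse the reading order of this document.", page)
text = model_chat("Read text in the image.", page)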
-
-
-def process_elements(layout_result, image):
+
+def process_elements(layout_results, image):
+    """Parse layout results and extract elements from the image."""
+    # Placeholder parsing logic based on expected Dolphin output
+    # Assuming layout_results is a string like "[(x1,y1,x2,y2,label), ...]"
    try:
-        elements = ast.literal_eval(layout_result)
+        elements = ast.literal_eval(layout_results)
    except:
-        elements = []
-
-    results, order = [], 0
+        elements = []  # Fallback if parsing fails
+
+    recognition_results = []
+    reading_order = 0
+
    for bbox, label in elements:
-        x1, y1, x2, y2 = map(int, bbox)
-        crop = image.crop((x1, y1, x2, y2))
-        if crop.width == 0 or crop.height == 0:
+        try:
+            x1, y1, x2, y2 = map(int, bbox)
+            cropped = image.crop((x1, y1, x2, y2))
+            if cropped.size[0] > 0 and cropped.size[1] > 0:
+                if label == "text":
+                    text = model_chat("Read text in the image.", cropped)
+                    recognition_results.append({
+                        "label": label,
+                        "bbox": [x1, y1, x2, y2],
+                        "text": text.strip(),
+                        "reading_order": reading_order
+                    })
+                elif label == "table":
+                    table_text = model_chat("Parse the table in the image.", cropped)
+                    recognition_results.append({
+                        "label": label,
+                        "bbox": [x1, y1, x2, y2],
+                        "text": table_text.strip(),
+                        "reading_order": reading_order
+                    })
+                elif label == "figure":
+                    recognition_results.append({
+                        "label": label,
+                        "bbox": [x1, y1, x2, y2],
+                        "text": "[Figure]",  # Placeholder for figure content
+                        "reading_order": reading_order
+                    })
+            reading_order += 1
+        except Exception as e:
+            print(f"Error processing element: {e}")
            continue
-
-        if label == "text":
-            txt = model_chat("Read text in the image.", crop)
-        elif label == "table":
-            txt = model_chat("Parse the table in the image.", crop)
-        else:
-            txt = "[Figure]"
-
-        results.append({
-            "label": label,
-            "bbox": [x1, y1, x2, y2],
-            "text": txt.strip(),
-            "reading_order": order
-        })
-        order += 1
-
-    return results
-
-
-def generate_markdown(recog):
-    md = ""
-    for el in sorted(recog, key=lambda x: x["reading_order"]):
-        if el["label"] == "text":
-            md += el["text"] + "\n\n"
-        elif el["label"] == "table":
-            md += f"**Table:**\n{el['text']}\n\n"
-        else:
-            md += el["text"] + "\n\n"
-    return md.strip()
+
+    return recognition_results
+
+def generate_markdown(recognition_results):
+    """Generate markdown from extracted elements."""
+    markdown = ""
+    for element in sorted(recognition_results, key=lambda x: x["reading_order"]):
+        if element["label"] == "text":
+            markdown += f"{element['text']}\n\n"
+        elif element["label"] == "table":
+            markdown += f"**Table:**\n{element['text']}\n\n"
+        elif element["label"] == "figure":
+            markdown += f"{element['text']}\n\n"
+    return markdown.strip()
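To make the data flow concrete, an illustrative call to generate_markdown with a hand-written list in the shape process_elements produces:

# Illustrative only; hand-written input, not real model output.
results = [
    {"label": "table", "bbox": [0, 40, 90, 80], "text": "| a | b |", "reading_order": 1},
    {"label": "text",  "bbox": [0, 0, 90, 30],  "text": "Title",     "reading_order": 0},
]
print(generate_markdown(results))
# Title
#
# **Table:**
# | a | b |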

def process_image_with_dolphin(image):
-    layout = model_chat("Parse the reading order of this document.", image)
-    elems = process_elements(layout, image)
-    return generate_markdown(elems)
+    """Process a single image with Dolphin model."""
+    layout_output = model_chat("Parse the reading order of this document.", image)
+    elements = process_elements(layout_output, image)
+    markdown_content = generate_markdown(elements)
+    return markdown_content

@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
@@ -208,78 +229,83 @@ def generate_image(model_name: str, text: str, image: Image.Image,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
+    """Generate responses for image input using the selected model."""
    if model_name == "ByteDance-s-Dolphin":
        if image is None:
            yield "Please upload an image."
-        else:
-            yield process_image_with_dolphin(image)
-        return
-
-    if model_name == "Nanonets-OCR-s":
-        proc, mdl = processor_m, model_m
-    elif model_name == "SmolDocling-256M-preview":
-        proc, mdl = processor_x, model_x
-    elif model_name == "MonkeyOCR-Recognition":
-        proc, mdl = processor_g, model_g
+            return
+        markdown_content = process_image_with_dolphin(image)
+        yield markdown_content
    else:
-        yield "Invalid model selected."
-        return
-
-    if image is None:
-        yield "Please upload an image."
-        return
-
-    imgs = [image]
-    if model_name == "SmolDocling-256M-preview":
-        if any(tok in text for tok in ["OTSL", "code"]):
-            imgs = [add_random_padding(img) for img in imgs]
-        if any(tok in text for tok in ["OCR at text", "Identify element", "formula"]):
-            text = normalize_values(text, target_max=500)
-
-    messages = [
-        {"role":"user",
-         "content":[{"type":"image"} for _ in imgs] + [{"type":"text","text":text}]
-        }
-    ]
-    prompt = proc.apply_chat_template(messages, add_generation_prompt=True)
-    inputs = proc(text=prompt, images=imgs, return_tensors="pt").to(device)
-
-    gen_cfg = GenerationConfig.from_model_config(mdl.config)
-    gen_cfg.max_new_tokens = max_new_tokens
-    gen_cfg.temperature = temperature
-    gen_cfg.top_p = top_p
-    gen_cfg.top_k = top_k
-    gen_cfg.repetition_penalty = repetition_penalty
-    gen_cfg.use_cache = True
-
-    streamer = TextIteratorStreamer(proc, skip_prompt=True, skip_special_tokens=True)
-    gen_kwargs = {
-        **inputs,
-        "streamer": streamer,
-        "generation_config": gen_cfg,
-    }
-
-    thread = Thread(target=mdl.generate, kwargs=gen_kwargs)
-    thread.start()
-
-    buffer = ""
-    full_output = ""
-    for new_text in streamer:
-        full_output += new_text
-        buffer += new_text.replace("<|im_end|>", "")
-        yield buffer
-
-    if model_name == "SmolDocling-256M-preview":
-        cleaned = full_output.replace("<end_of_utterance>", "").strip()
-        if any(tag in cleaned for tag in ["<doctag>","<otsl>","<code>","<chart>","<formula>"]):
-            if "<chart>" in cleaned:
-                cleaned = cleaned.replace("<chart>","<otsl>").replace("</chart>","</otsl>")
-            cleaned = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned)
-            tags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned], imgs)
-            doc = DoclingDocument.load_from_doctags(tags_doc, document_name="Document")
-            yield f"**MD Output:**\n\n{doc.export_to_markdown()}"
-        else:
-            yield cleaned
+        # Existing logic for other models
+        if model_name == "Nanonets-OCR-s":
+            processor = processor_m
+            model = model_m
+        elif model_name == "MonkeyOCR-Recognition":
+            processor = processor_g
+            model = model_g
+        elif model_name == "SmolDocling-256M-preview":
+            processor = processor_x
+            model = model_x
+        else:
+            yield "Invalid model selected."
+            return
+
+        if image is None:
+            yield "Please upload an image."
+            return
+
+        images = [image]
+
+        if model_name == "SmolDocling-256M-preview":
+            if "OTSL" in text or "code" in text:
+                images = [add_random_padding(img) for img in images]
+            if "OCR at text at" in text or "Identify element" in text or "formula" in text:
+                text = normalize_values(text, target_max=500)
+
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "image"} for _ in images] + [
+                    {"type": "text", "text": text}
+                ]
+            }
+        ]
+        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+        inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        buffer = ""
+        full_output = ""
+        for new_text in streamer:
+            full_output += new_text
+            buffer += new_text.replace("<|im_end|>", "")
+            yield buffer
+
+        if model_name == "SmolDocling-256M-preview":
+            cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
+            if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
+                if "<chart>" in cleaned_output:
+                    cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
+                cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
+                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
+                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
+                markdown_output = doc.export_to_markdown()
+                yield f"**MD Output:**\n\n{markdown_output}"
+            else:
+                yield cleaned_output

@spaces.GPU
def generate_video(model_name: str, text: str, video_path: str,
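The non-Dolphin branch above is the standard TextIteratorStreamer pattern: generate() blocks, so it runs on a background thread while the generator yields the growing buffer. A self-contained sketch of that pattern with a small text-only model (illustrative; the Space itself uses the VLM processors defined earlier):

# Illustrative only; demonstrates the streaming pattern, not the Space's models.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("Once upon a time", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate,
       kwargs={**inputs, "streamer": streamer, "max_new_tokens": 30}).start()
buffer = ""
for new_text in streamer:
    buffer += new_text  # a Gradio generator would `yield buffer` here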
@@ -288,77 +314,97 @@ def generate_video(model_name: str, text: str, video_path: str,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
+    """Generate responses for video input using the selected model."""
    if model_name == "ByteDance-s-Dolphin":
-        if not video_path:
+        if video_path is None:
            yield "Please upload a video."
            return
-        md_list = []
-        for frame, _ in downsample_video(video_path):
-            md_list.append(process_image_with_dolphin(frame))
-        yield "\n\n".join(md_list)
-        return
-
-    if model_name == "Nanonets-OCR-s":
-        proc, mdl = processor_m, model_m
-    elif model_name == "SmolDocling-256M-preview":
-        proc, mdl = processor_x, model_x
-    elif model_name == "MonkeyOCR-Recognition":
-        proc, mdl = processor_g, model_g
+        frames = downsample_video(video_path)
+        markdown_contents = []
+        for frame, _ in frames:
+            markdown_content = process_image_with_dolphin(frame)
+            markdown_contents.append(markdown_content)
+        combined_markdown = "\n\n".join(markdown_contents)
+        yield combined_markdown
    else:
-        yield "Invalid model selected."
-        return
-
-    if not video_path:
-        yield "Please upload a video."
-        return
-
-    frames = [f for f, _ in downsample_video(video_path)]
-    imgs = frames
-    if model_name == "SmolDocling-256M-preview":
-        if any(tok in text for tok in ["OTSL", "code"]):
-            imgs = [add_random_padding(img) for img in imgs]
-        if any(tok in text for tok in ["OCR at text", "Identify element", "formula"]):
-            text = normalize_values(text, target_max=500)
-
-    messages = [
-        {"role":"user",
-         "content":[{"type":"image"} for _ in imgs] + [{"type":"text","text":text}]
-        }
-    ]
-    prompt = proc.apply_chat_template(messages, add_generation_prompt=True)
-    inputs = proc(text=prompt, images=imgs, return_tensors="pt").to(device)
-
-    gen_cfg = GenerationConfig.from_model_config(mdl.config)
-    gen_cfg.max_new_tokens = max_new_tokens
-    gen_cfg.temperature = temperature
-    gen_cfg.top_p = top_p
-    gen_cfg.top_k = top_k
-    gen_cfg.repetition_penalty = repetition_penalty
-    gen_cfg.use_cache = True
-
-    streamer = TextIteratorStreamer(proc, skip_prompt=True, skip_special_tokens=True)
-    gen_kwargs = {
-        **inputs,
-        "streamer": streamer,
-        "generation_config": gen_cfg,
-    }
-
-    thread = Thread(target=mdl.generate, kwargs=gen_kwargs)
-    thread.start()
-
-    buff = ""
-    full = ""
-    for nt in streamer:
-        full += nt
-        buff += nt.replace("<|im_end|>", "")
-        yield buff
-
-    # Gradio UI
+        # Existing logic for other models
+        if model_name == "Nanonets-OCR-s":
+            processor = processor_m
+            model = model_m
+        elif model_name == "MonkeyOCR-Recognition":
+            processor = processor_g
+            model = model_g
+        elif model_name == "SmolDocling-256M-preview":
+            processor = processor_x
+            model = model_x
+        else:
+            yield "Invalid model selected."
+            return
+
+        if video_path is None:
+            yield "Please upload a video."
+            return
+
+        frames = downsample_video(video_path)
+        images = [frame for frame, _ in frames]
+
+        if model_name == "SmolDocling-256M-preview":
+            if "OTSL" in text or "code" in text:
+                images = [add_random_padding(img) for img in images]
+            if "OCR at text at" in text or "Identify element" in text or "formula" in text:
+                text = normalize_values(text, target_max=500)
+
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "image"} for _ in images] + [
+                    {"type": "text", "text": text}
+                ]
+            }
+        ]
+        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+        inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        buffer = ""
+        full_output = ""
+        for new_text in streamer:
+            full_output += new_text
+            buffer += new_text.replace("<|im_end|>", "")
+            yield buffer
+
+        if model_name == "SmolDocling-256M-preview":
+            cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
+            if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
+                if "<chart>" in cleaned_output:
+                    cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
+                cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
+                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
+                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
+                markdown_output = doc.export_to_markdown()
+                yield f"**MD Output:**\n\n{markdown_output}"
+            else:
+                yield cleaned_output
+
+# Define examples for image and video inference
image_examples = [
    ["Convert this page to docling", "images/1.png"],
    ["OCR the image", "images/2.jpg"],
    ["Convert this page to docling", "images/3.png"],
]
+
video_examples = [
    ["Explain the ad in detail", "example/1.mp4"],
    ["Identify the main actions in the coca cola ad...", "example/2.mp4"]
@@ -374,6 +420,7 @@ css = """
}
"""

+# Create the Gradio Interface
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
    with gr.Row():
@@ -408,6 +455,7 @@
        label="Select Model",
        value="Nanonets-OCR-s"
    )
+
    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
 