prithivMLmods commited on
Commit
42b280c
·
verified ·
1 Parent(s): 451f8cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -408
app.py CHANGED
@@ -10,26 +10,17 @@ import gradio as gr
10
  import spaces
11
  import torch
12
  import numpy as np
13
- from PIL import Image, ImageOps
14
  import cv2
15
- import pymupdf
16
- import io
17
 
18
  from transformers import (
19
  Qwen2VLForConditionalGeneration,
20
- VisionEncoderDecoderModel,
21
- AutoModelForVision2Seq,
22
  AutoProcessor,
23
  TextIteratorStreamer,
24
  )
25
  from transformers.image_utils import load_image
26
 
27
- from docling_core.types.doc import DoclingDocument, DocTagsDocument
28
-
29
- import re
30
- import ast
31
- import html
32
-
33
  # Constants for text generation
34
  MAX_MAX_NEW_TOKENS = 2048
35
  DEFAULT_MAX_NEW_TOKENS = 1024
@@ -37,71 +28,29 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
37
 
38
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
39
 
40
- # Global variables for Dolphin model
41
- model_k = None
42
- processor_k = None
43
- tokenizer_k = None
44
-
45
- # Load models
46
- def initialize_models():
47
- global model_k, processor_k, tokenizer_k
48
- # Load olmOCR-7B-0225-preview
49
- MODEL_ID_M = "allenai/olmOCR-7B-0225-preview"
50
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
51
- model_m = Qwen2VLForConditionalGeneration.from_pretrained(
52
- MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
53
- ).to(device).eval()
54
-
55
- # Load ByteDance's Dolphin
56
- MODEL_ID_K = "ByteDance/Dolphin"
57
- processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
58
- if model_k is None:
59
- model_k = VisionEncoderDecoderModel.from_pretrained(
60
- MODEL_ID_K, trust_remote_code=True, torch_dtype=torch.float16
61
- ).to(device).eval()
62
- tokenizer_k = processor_k.tokenizer
63
-
64
- # Load SmolDocling-256M-preview
65
- MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
66
- processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
67
- model_x = AutoModelForVision2Seq.from_pretrained(
68
- MODEL_ID_X, trust_remote_code=True, torch_dtype=torch.float16
69
- ).to(device).eval()
70
-
71
- return processor_m, model_m, processor_x, model_x
72
-
73
- processor_m, model_m, processor_x, model_x = initialize_models()
74
-
75
- # Preprocessing functions for SmolDocling-256M
76
- def add_random_padding(image, min_percent=0.1, max_percent=0.10):
77
- """Add random padding to an image based on its size."""
78
- image = image.convert("RGB")
79
- width, height = image.size
80
- pad_w_percent = random.uniform(min_percent, max_percent)
81
- pad_h_percent = random.uniform(min_percent, max_percent)
82
- pad_w = int(width * pad_w_percent)
83
- pad_h = int(height * pad_h_percent)
84
- corner_pixel = image.getpixel((0, 0))
85
- padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
86
- return padded_image
87
-
88
- def normalize_values(text, target_max=500):
89
- """Normalize numerical values in text to a target maximum."""
90
- def normalize_list(values):
91
- max_value = max(values) if values else 1
92
- return [round((v / max_value) * target_max) for v in values]
93
-
94
- def process_match(match):
95
- num_list = ast.literal_eval(match.group(0))
96
- normalized = normalize_list(num_list)
97
- return "".join([f"<loc_{num}>" for num in normalized])
98
-
99
- pattern = r"\[([\d\.\s,]+)\]"
100
- normalized_text = re.sub(pattern, process_match, text)
101
- return normalized_text
102
 
103
  def downsample_video(video_path):
104
- """Downsample a video to evenly spaced frames, returning PIL images with timestamps."""
 
 
 
105
  vidcap = cv2.VideoCapture(video_path)
106
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
107
  fps = vidcap.get(cv2.CAP_PROP_FPS)
@@ -118,343 +67,128 @@ def downsample_video(video_path):
118
  vidcap.release()
119
  return frames
120
 
121
- # Dolphin-specific functions
122
- @spaces.GPU
123
- def model_chat(prompt, image, is_batch=False):
124
- """Use Dolphin model for inference, supporting both single and batch processing."""
125
- global model_k, processor_k, tokenizer_k
126
- if model_k is None:
127
- initialize_models()
128
-
129
- if not is_batch:
130
- images = [image]
131
- prompts = [prompt]
132
- else:
133
- images = image
134
- prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
135
-
136
- inputs = processor_k(images, return_tensors="pt", padding=True).to(device)
137
- pixel_values = inputs.pixel_values.half()
138
-
139
- prompts = [f"<s>{p} <Answer/>" for p in prompts]
140
- prompt_inputs = tokenizer_k(
141
- prompts, add_special_tokens=False, return_tensors="pt", padding=True
142
- ).to(device)
143
-
144
- outputs = model_k.generate(
145
- pixel_values=pixel_values,
146
- decoder_input_ids=prompt_inputs.input_ids,
147
- decoder_attention_mask=prompt_inputs.attention_mask,
148
- min_length=1,
149
- max_length=4096,
150
- pad_token_id=tokenizer_k.pad_token_id,
151
- eos_token_id=tokenizer_k.eos_token_id,
152
- use_cache=True,
153
- bad_words_ids=[[tokenizer_k.unk_token_id]],
154
- return_dict_in_generate=True,
155
- do_sample=False,
156
- num_beams=1,
157
- repetition_penalty=1.1
158
- )
159
- sequences = tokenizer_k.batch_decode(outputs.sequences, skip_special_tokens=False)
160
-
161
- results = []
162
- for i, sequence in enumerate(sequences):
163
- cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
164
- results.append(cleaned)
165
-
166
- return results[0] if not is_batch else results
167
-
168
- @spaces.GPU
169
- def process_element_batch(elements, prompt, max_batch_size=16):
170
- """Process a batch of elements with the same prompt."""
171
- results = []
172
- batch_size = min(len(elements), max_batch_size)
173
-
174
- for i in range(0, len(elements), batch_size):
175
- batch_elements = elements[i:i + batch_size]
176
- crops_list = [elem["crop"] for elem in batch_elements]
177
- prompts_list = [prompt] * len(crops_list)
178
-
179
- batch_results = model_chat(prompts_list, crops_list, is_batch=True)
180
-
181
- for j, result in enumerate(batch_results):
182
- elem = batch_elements[j]
183
- results.append({
184
- "label": elem["label"],
185
- "bbox": elem["bbox"],
186
- "text": result.strip(),
187
- "reading_order": elem["reading_order"],
188
- })
189
-
190
- return results
191
-
192
- def process_elements(layout_results, image):
193
- """Parse layout results and extract elements from the image."""
194
- try:
195
- elements = ast.literal_eval(layout_results)
196
- except:
197
- elements = []
198
-
199
- text_elements = []
200
- table_elements = []
201
- figure_results = []
202
- reading_order = 0
203
-
204
- for bbox, label in elements:
205
- try:
206
- x1, y1, x2, y2 = map(int, bbox)
207
- cropped = image.crop((x1, y1, x2, y2))
208
- if cropped.size[0] > 0 and cropped.size[1] > 0:
209
- element_info = {
210
- "crop": cropped,
211
- "label": label,
212
- "bbox": [x1, y1, x2, y2],
213
- "reading_order": reading_order,
214
- }
215
- if label == "text":
216
- text_elements.append(element_info)
217
- elif label == "table":
218
- table_elements.append(element_info)
219
- elif label == "figure":
220
- figure_results.append({
221
- "label": label,
222
- "bbox": [x1, y1, x2, y2],
223
- "text": "[Figure]",
224
- "reading_order": reading_order
225
- })
226
- reading_order += 1
227
- except Exception as e:
228
- print(f"Error processing element: {e}")
229
- continue
230
-
231
- recognition_results = figure_results.copy()
232
-
233
- if text_elements:
234
- text_results = process_element_batch(text_elements, "Read text in the image.")
235
- recognition_results.extend(text_results)
236
-
237
- if table_elements:
238
- table_results = process_element_batch(table_elements, "Parse the table in the image.")
239
- recognition_results.extend(table_results)
240
-
241
- recognition_results.sort(key=lambda x: x["reading_order"])
242
- return recognition_results
243
-
244
- def generate_markdown(recognition_results):
245
- """Generate markdown from extracted elements."""
246
- markdown = ""
247
- for element in recognition_results:
248
- if element["label"] == "text":
249
- markdown += f"{element['text']}\n\n"
250
- elif element["label"] == "table":
251
- markdown += f"**Table:**\n{element['text']}\n\n"
252
- elif element["label"] == "figure":
253
- markdown += f"{element['text']}\n\n"
254
- return markdown.strip()
255
-
256
- def convert_to_image(image):
257
- """Convert uploaded file to PIL Image, handling PDFs by extracting the first page."""
258
- if isinstance(image, str): # File path from Gradio
259
- if image.lower().endswith('.pdf'):
260
- doc = pymupdf.open(image)
261
- page = doc[0]
262
- pix = page.get_pixmap()
263
- img_data = pix.tobytes("png")
264
- pil_image = Image.open(io.BytesIO(img_data)).convert("RGB")
265
- doc.close()
266
- return pil_image
267
- else:
268
- return Image.open(image).convert("RGB")
269
- elif isinstance(image, Image.Image): # Already a PIL Image
270
- return image.convert("RGB")
271
- return None
272
-
273
- def process_image_with_dolphin(image):
274
- """Process a single image with Dolphin model."""
275
- pil_image = convert_to_image(image)
276
- if pil_image is None:
277
- return "Error: Unable to process the uploaded file."
278
- layout_output = model_chat("Parse the reading order of this document.", pil_image)
279
- elements = process_elements(layout_output, pil_image)
280
- markdown_content = generate_markdown(elements)
281
- return markdown_content
282
-
283
  @spaces.GPU
284
  def generate_image(model_name: str, text: str, image: Image.Image,
285
- max_new_tokens: int = 1024, temperature: float = 0.6,
286
- top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
287
- """Generate responses for image input using the selected model."""
288
- if model_name == "ByteDance-s-Dolphin":
289
- if image is None:
290
- yield "Please upload an image or PDF (first page will be processed)."
291
- return
292
- markdown_content = process_image_with_dolphin(image)
293
- yield markdown_content
 
 
 
 
 
294
  else:
295
- if model_name == "olmOCR-7B-0225-preview":
296
- processor = processor_m
297
- model = model_m
298
- elif model_name == "SmolDocling-256M-preview":
299
- processor = processor_x
300
- model = model_x
301
- else:
302
- yield "Invalid model selected."
303
- return
304
-
305
- if image is None:
306
- yield "Please upload an image."
307
- return
308
-
309
- images = [convert_to_image(image)]
310
- if images[0] is None:
311
- yield "Error: Unable to process the uploaded file."
312
- return
313
-
314
- if model_name == "SmolDocling-256M-preview":
315
- if "OTSL" in text or "code" in text:
316
- images = [add_random_padding(img) for img in images]
317
- if "OCR at text at" in text or "Identify element" in text or "formula" in text:
318
- text = normalize_values(text, target_max=500)
319
-
320
- messages = [
321
- {
322
- "role": "user",
323
- "content": [{"type": "image"} for _ in images] + [
324
- {"type": "text", "text": text}
325
- ]
326
- }
327
  ]
328
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
329
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
330
-
331
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
332
- generation_kwargs = {
333
- **inputs,
334
- "streamer": streamer,
335
- "max_new_tokens": max_new_tokens,
336
- "temperature": temperature,
337
- "top_p": top_p,
338
- "top_k": top_k,
339
- "repetition_penalty": repetition_penalty,
340
- }
341
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
342
- thread.start()
343
-
344
- buffer = ""
345
- full_output = ""
346
- for new_text in streamer:
347
- full_output += new_text
348
- buffer += new_text.replace("<|im_end|>", "")
349
- yield buffer
350
-
351
- if model_name == "SmolDocling-256M-preview":
352
- cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
353
- if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
354
- if "<chart>" in cleaned_output:
355
- cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
356
- cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
357
- doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
358
- doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
359
- markdown_output = doc.export_to_markdown()
360
- yield f"**MD Output:**\n\n{markdown_output}"
361
- else:
362
- yield cleaned_output
363
 
364
  @spaces.GPU
365
  def generate_video(model_name: str, text: str, video_path: str,
366
- max_new_tokens: int = 1024, temperature: float = 0.6,
367
- top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
368
- """Generate responses for video input using the selected model."""
369
- if model_name == "ByteDance-s-Dolphin":
370
- if video_path is None:
371
- yield "Please upload a video."
372
- return
373
- frames = downsample_video(video_path)
374
- markdown_contents = []
375
- for idx, (frame, _) in enumerate(frames):
376
- markdown_content = process_image_with_dolphin(frame)
377
- markdown_contents.append(f"**Frame {idx + 1}:**\n{markdown_content}")
378
- combined_markdown = "\n\n---\n\n".join(markdown_contents)
379
- yield combined_markdown
380
  else:
381
- if model_name == "olmOCR-7B-0225-preview":
382
- processor = processor_m
383
- model = model_m
384
- elif model_name == "SmolDocling-256M-preview":
385
- processor = processor_x
386
- model = model_x
387
- else:
388
- yield "Invalid model selected."
389
- return
390
-
391
- if video_path is None:
392
- yield "Please upload a video."
393
- return
394
-
395
- frames = downsample_video(video_path)
396
- images = [frame for frame, _ in frames]
397
-
398
- if model_name == "SmolDocling-256M-preview":
399
- if "OTSL" in text or "code" in text:
400
- images = [add_random_padding(img) for img in images]
401
- if "OCR at text at" in text or "Identify element" in text or "formula" in text:
402
- text = normalize_values(text, target_max=500)
403
-
404
- messages = [
405
- {
406
- "role": "user",
407
- "content": [{"type": "image"} for _ in images] + [
408
- {"type": "text", "text": text}
409
- ]
410
- }
411
- ]
412
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
413
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
414
-
415
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
416
- generation_kwargs = {
417
- **inputs,
418
- "streamer": streamer,
419
- "max_new_tokens": max_new_tokens,
420
- "temperature": temperature,
421
- "top_p": top_p,
422
- "top_k": top_k,
423
- "repetition_penalty": repetition_penalty,
424
- }
425
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
426
- thread.start()
427
-
428
- buffer = ""
429
- full_output = ""
430
- for new_text in streamer:
431
- full_output += new_text
432
- buffer += new_text.replace("<|im_end|>", "")
433
- yield buffer
434
-
435
- if model_name == "SmolDocling-256M-preview":
436
- cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
437
- if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
438
- if "<chart>" in cleaned_output:
439
- cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
440
- cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
441
- doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
442
- doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
443
- markdown_output = doc.export_to_markdown()
444
- yield f"**MD Output:**\n\n{markdown_output}"
445
- else:
446
- yield cleaned_output
447
-
448
- # Define examples
449
  image_examples = [
450
- ["Convert this page to docling", "images/1.png"],
451
- ["OCR the image", "images/2.jpg"],
452
- ["Convert this page to docling", "images/3.png"],
453
  ]
454
 
455
  video_examples = [
456
- ["Explain the ad in detail", "example/1.mp4"],
457
- ["Identify the main actions in the coca cola ad...", "example/2.mp4"]
458
  ]
459
 
460
  css = """
@@ -467,23 +201,28 @@ css = """
467
  }
468
  """
469
 
470
- # Create Gradio Interface
471
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
472
- gr.Markdown("# **[Docling-VLMs](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
473
- gr.Markdown("**Note:** For Dolphin model, the text query is ignored, and PDFs are processed by parsing the first page.")
474
  with gr.Row():
475
  with gr.Column():
476
  with gr.Tabs():
477
  with gr.TabItem("Image Inference"):
478
  image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
479
- image_upload = gr.Image(type="pil", label="Image or PDF")
480
  image_submit = gr.Button("Submit", elem_classes="submit-btn")
481
- gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
 
 
 
482
  with gr.TabItem("Video Inference"):
483
  video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
484
  video_upload = gr.Video(label="Video")
485
  video_submit = gr.Button("Submit", elem_classes="submit-btn")
486
- gr.Examples(examples=video_examples, inputs=[video_query, video_upload])
 
 
 
487
  with gr.Accordion("Advanced options", open=False):
488
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
489
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
@@ -491,13 +230,19 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
491
  top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
492
  repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
493
  with gr.Column():
494
- output = gr.Textbox(label="Output", interactive=False, lines=3, scale=2)
495
  model_choice = gr.Radio(
496
- choices=["olmOCR-7B-0225-preview", "SmolDocling-256M-preview", "ByteDance-s-Dolphin"],
497
  label="Select Model",
498
- value="olmOCR-7B-0225-preview"
499
  )
500
 
 
 
 
 
 
 
501
  image_submit.click(
502
  fn=generate_image,
503
  inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
 
10
  import spaces
11
  import torch
12
  import numpy as np
13
+ from PIL import Image
14
  import cv2
 
 
15
 
16
  from transformers import (
17
  Qwen2VLForConditionalGeneration,
18
+ Qwen2_5_VLForConditionalGeneration,
 
19
  AutoProcessor,
20
  TextIteratorStreamer,
21
  )
22
  from transformers.image_utils import load_image
23
 
 
 
 
 
 
 
24
  # Constants for text generation
25
  MAX_MAX_NEW_TOKENS = 2048
26
  DEFAULT_MAX_NEW_TOKENS = 1024
 
28
 
29
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
30
 
31
+ # Load VIREX-062225-exp
32
+ MODEL_ID_M = "prithivMLmods/VIREX-062225-exp"
33
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
34
+ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
35
+ MODEL_ID_M,
36
+ trust_remote_code=True,
37
+ torch_dtype=torch.float16
38
+ ).to(device).eval()
39
+
40
+ # Load DREX-062225-exp
41
+ MODEL_ID_X = "prithivMLmods/DREX-062225-exp"
42
+ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
43
+ model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
44
+ MODEL_ID_X,
45
+ trust_remote_code=True,
46
+ torch_dtype=torch.float16
47
+ ).to(device).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def downsample_video(video_path):
50
+ """
51
+ Downsamples the video to evenly spaced frames.
52
+ Each frame is returned as a PIL image along with its timestamp.
53
+ """
54
  vidcap = cv2.VideoCapture(video_path)
55
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
56
  fps = vidcap.get(cv2.CAP_PROP_FPS)
 
67
  vidcap.release()
68
  return frames
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  @spaces.GPU
71
  def generate_image(model_name: str, text: str, image: Image.Image,
72
+ max_new_tokens: int = 1024,
73
+ temperature: float = 0.6,
74
+ top_p: float = 0.9,
75
+ top_k: int = 50,
76
+ repetition_penalty: float = 1.2):
77
+ """
78
+ Generates responses using the selected model for image input.
79
+ """
80
+ if model_name == "VIREX-062225-exp":
81
+ processor = processor_m
82
+ model = model_m
83
+ elif model_name == "DREX-062225-exp":
84
+ processor = processor_x
85
+ model = model_x
86
  else:
87
+ yield "Invalid model selected."
88
+ return
89
+
90
+ if image is None:
91
+ yield "Please upload an image."
92
+ return
93
+
94
+ messages = [{
95
+ "role": "user",
96
+ "content": [
97
+ {"type": "image", "image": image},
98
+ {"type": "text", "text": text},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  ]
100
+ }]
101
+ prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
102
+ inputs = processor(
103
+ text=[prompt_full],
104
+ images=[image],
105
+ return_tensors="pt",
106
+ padding=True,
107
+ truncation=False,
108
+ max_length=MAX_INPUT_TOKEN_LENGTH
109
+ ).to(device)
110
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
111
+ generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
112
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
113
+ thread.start()
114
+ buffer = ""
115
+ for new_text in streamer:
116
+ buffer += new_text
117
+ buffer = buffer.replace("<|im_end|>", "")
118
+ time.sleep(0.01)
119
+ yield buffer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  @spaces.GPU
122
  def generate_video(model_name: str, text: str, video_path: str,
123
+ max_new_tokens: int = 1024,
124
+ temperature: float = 0.6,
125
+ top_p: float = 0.9,
126
+ top_k: int = 50,
127
+ repetition_penalty: float = 1.2):
128
+ """
129
+ Generates responses using the selected model for video input.
130
+ """
131
+ if model_name == "VIREX-062225-exp":
132
+ processor = processor_m
133
+ model = model_m
134
+ elif model_name == "DREX-062225-exp":
135
+ processor = processor_x
136
+ model = model_x
137
  else:
138
+ yield "Invalid model selected."
139
+ return
140
+
141
+ if video_path is None:
142
+ yield "Please upload a video."
143
+ return
144
+
145
+ frames = downsample_video(video_path)
146
+ messages = [
147
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
148
+ {"role": "user", "content": [{"type": "text", "text": text}]}
149
+ ]
150
+ for frame in frames:
151
+ image, timestamp = frame
152
+ messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
153
+ messages[1]["content"].append({"type": "image", "image": image})
154
+ inputs = processor.apply_chat_template(
155
+ messages,
156
+ tokenize=True,
157
+ add_generation_prompt=True,
158
+ return_dict=True,
159
+ return_tensors="pt",
160
+ truncation=False,
161
+ max_length=MAX_INPUT_TOKEN_LENGTH
162
+ ).to(device)
163
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
164
+ generation_kwargs = {
165
+ **inputs,
166
+ "streamer": streamer,
167
+ "max_new_tokens": max_new_tokens,
168
+ "do_sample": True,
169
+ "temperature": temperature,
170
+ "top_p": top_p,
171
+ "top_k": top_k,
172
+ "repetition_penalty": repetition_penalty,
173
+ }
174
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
175
+ thread.start()
176
+ buffer = ""
177
+ for new_text in streamer:
178
+ buffer += new_text
179
+ buffer = buffer.replace("<|im_end|>", "")
180
+ time.sleep(0.01)
181
+ yield buffer
182
+
183
+ # Define examples for image and video inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  image_examples = [
185
+ ["Perform OCR on the Image.", "images/1.jpg"],
186
+ ["Extract the table content", "images/2.png"]
 
187
  ]
188
 
189
  video_examples = [
190
+ ["Explain the Ad in Detail", "videos/1.mp4"],
191
+ ["Identify the main actions in the cartoon video", "videos/2.mp4"]
192
  ]
193
 
194
  css = """
 
201
  }
202
  """
203
 
204
+ # Create the Gradio Interface
205
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
206
+ gr.Markdown("# **Multimodal OCR**")
 
207
  with gr.Row():
208
  with gr.Column():
209
  with gr.Tabs():
210
  with gr.TabItem("Image Inference"):
211
  image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
212
+ image_upload = gr.Image(type="pil", label="Image")
213
  image_submit = gr.Button("Submit", elem_classes="submit-btn")
214
+ gr.Examples(
215
+ examples=image_examples,
216
+ inputs=[image_query, image_upload]
217
+ )
218
  with gr.TabItem("Video Inference"):
219
  video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
220
  video_upload = gr.Video(label="Video")
221
  video_submit = gr.Button("Submit", elem_classes="submit-btn")
222
+ gr.Examples(
223
+ examples=video_examples,
224
+ inputs=[video_query, video_upload]
225
+ )
226
  with gr.Accordion("Advanced options", open=False):
227
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
228
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
 
230
  top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
231
  repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
232
  with gr.Column():
233
+ output = gr.Textbox(label="Output", interactive=False, lines=2, scale=2)
234
  model_choice = gr.Radio(
235
+ choices=["DREX-062225-exp", "VIREX-062225-exp"],
236
  label="Select Model",
237
+ value="VIREX-062225-exp"
238
  )
239
 
240
+ gr.Markdown("**Model Info 💻 | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Doc-VLMs/discussions)**")
241
+ gr.Markdown("> [Qwen2-VL-OCR-2B-Instruct](https://huggingface.co/prithivMLmods/Qwen2-VL-OCR-2B-Instruct): qwen2-vl-ocr-2b-instruct model is a fine-tuned version of qwen2-vl-2b-instruct, tailored for tasks that involve [messy] optical character recognition (ocr), image-to-text conversion, and math problem solving with latex formatting.")
242
+ gr.Markdown("> [Nanonets-OCR-s](https://huggingface.co/nanonets/Nanonets-OCR-s): nanonets-ocr-s is a powerful, state-of-the-art image-to-markdown ocr model that goes far beyond traditional text extraction. it transforms documents into structured markdown with intelligent content recognition and semantic tagging.")
243
+ gr.Markdown("> [RolmOCR](https://huggingface.co/reducto/RolmOCR): rolmocr, high-quality, openly available approach to parsing pdfs and other complex documents oprical character recognition. it is designed to handle a wide range of document types, including scanned documents, handwritten text, and complex layouts.")
244
+ gr.Markdown("> [Aya-Vision](https://huggingface.co/CohereLabs/aya-vision-8b): cohere labs aya vision 8b is an open weights research release of an 8-billion parameter model with advanced capabilities optimized for a variety of vision-language use cases, including ocr, captioning, visual reasoning, summarization, question answering, code, and more.")
245
+
246
  image_submit.click(
247
  fn=generate_image,
248
  inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],