prithivMLmods committed on
Commit 2f295a3 · verified · 1 Parent(s): fd30e2b

Update app.py

Files changed (1):
  1. app.py +89 -69
app.py CHANGED
@@ -12,6 +12,7 @@ import torch
 import numpy as np
 from PIL import Image, ImageOps
 import cv2
+import pymupdf
 
 from transformers import (
     Qwen2VLForConditionalGeneration,
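
A note on the new dependency: `import pymupdf` is the modern top-level name for PyMuPDF; older releases expose the same API as `import fitz`. A quick sanity check, assuming PyMuPDF is installed:

```python
# Quick sanity check for the new dependency (pip install pymupdf).
# On older PyMuPDF releases the same API is imported as `import fitz`.
import pymupdf

print(pymupdf.__doc__)  # version banner, e.g. "PyMuPDF 1.24.x: ..."
```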
@@ -35,32 +36,40 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Load olmOCR-7B-0225-preview
-MODEL_ID_M = "allenai/olmOCR-7B-0225-preview"
-processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
-model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_M,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to(device).eval()
-
-# Load ByteDance's Dolphin
-MODEL_ID_K = "ByteDance/Dolphin"
-processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
-model_k = VisionEncoderDecoderModel.from_pretrained(
-    MODEL_ID_K,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to(device).eval()
-
-# Load SmolDocling-256M-preview
-MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
-processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
-model_x = AutoModelForVision2Seq.from_pretrained(
-    MODEL_ID_X,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to(device).eval()
+# Global variables for Dolphin model
+model_k = None
+processor_k = None
+tokenizer_k = None
+
+# Load models
+def initialize_models():
+    global model_k, processor_k, tokenizer_k
+    # Load olmOCR-7B-0225-preview
+    MODEL_ID_M = "allenai/olmOCR-7B-0225-preview"
+    processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
+    model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+        MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
+    ).to(device).eval()
+
+    # Load ByteDance's Dolphin
+    MODEL_ID_K = "ByteDance/Dolphin"
+    processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
+    if model_k is None:
+        model_k = VisionEncoderDecoderModel.from_pretrained(
+            MODEL_ID_K, trust_remote_code=True, torch_dtype=torch.float16
+        ).to(device).eval()
+    tokenizer_k = processor_k.tokenizer
+
+    # Load SmolDocling-256M-preview
+    MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
+    processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
+    model_x = AutoModelForVision2Seq.from_pretrained(
+        MODEL_ID_X, trust_remote_code=True, torch_dtype=torch.float16
+    ).to(device).eval()
+
+    return processor_m, model_m, processor_x, model_x
+
+processor_m, model_m, processor_x, model_x = initialize_models()
 
 # Preprocessing functions for SmolDocling-256M
 def add_random_padding(image, min_percent=0.1, max_percent=0.10):
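
The hunk above replaces three eager module-level loads with globals plus an `initialize_models()` that can be re-run; `model_k` is guarded so Dolphin's weights are loaded only once. A minimal sketch of that lazy-singleton pattern in isolation (the `load_weights` stand-in is hypothetical):

```python
# Minimal sketch of the lazy-singleton pattern the commit adopts for Dolphin:
# a module-level global starts as None and is filled on first use, so repeated
# calls (e.g. from per-request @spaces.GPU workers) reuse the loaded weights.
_model = None

def load_weights():
    # Hypothetical stand-in for VisionEncoderDecoderModel.from_pretrained(...)
    return object()

def get_model():
    global _model
    if _model is None:
        _model = load_weights()
    return _model

assert get_model() is get_model()  # same instance on every call
```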
@@ -71,7 +80,7 @@ def add_random_padding(image, min_percent=0.1, max_percent=0.10):
     pad_h_percent = random.uniform(min_percent, max_percent)
     pad_w = int(width * pad_w_percent)
     pad_h = int(height * pad_h_percent)
-    corner_pixel = image.getpixel((0, 0))  # Top-left corner
+    corner_pixel = image.getpixel((0, 0))
     padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
     return padded_image
 
@@ -109,11 +118,12 @@ def downsample_video(video_path):
     return frames
 
 # Dolphin-specific functions
+@spaces.GPU
 def model_chat(prompt, image, is_batch=False):
     """Use Dolphin model for inference, supporting both single and batch processing."""
-    processor = processor_k
-    model = model_k
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    global model_k, processor_k, tokenizer_k
+    if model_k is None:
+        initialize_models()
 
     if not is_batch:
         images = [image]
@@ -122,33 +132,30 @@
         images = image
         prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
 
-    inputs = processor(images, return_tensors="pt", padding=True).to(device)
+    inputs = processor_k(images, return_tensors="pt", padding=True).to(device)
     pixel_values = inputs.pixel_values.half()
 
     prompts = [f"<s>{p} <Answer/>" for p in prompts]
-    prompt_inputs = processor.tokenizer(
-        prompts,
-        add_special_tokens=False,  # Explicitly set to False
-        return_tensors="pt",
-        padding=True
+    prompt_inputs = tokenizer_k(
+        prompts, add_special_tokens=False, return_tensors="pt", padding=True
     ).to(device)
 
-    outputs = model.generate(
+    outputs = model_k.generate(
         pixel_values=pixel_values,
         decoder_input_ids=prompt_inputs.input_ids,
         decoder_attention_mask=prompt_inputs.attention_mask,
         min_length=1,
         max_length=4096,
-        pad_token_id=processor.tokenizer.pad_token_id,
-        eos_token_id=processor.tokenizer.eos_token_id,
+        pad_token_id=tokenizer_k.pad_token_id,
+        eos_token_id=tokenizer_k.eos_token_id,
         use_cache=True,
-        bad_words_ids=[[processor.tokenizer.unk_token_id]],
+        bad_words_ids=[[tokenizer_k.unk_token_id]],
         return_dict_in_generate=True,
         do_sample=False,
         num_beams=1,
         repetition_penalty=1.1
     )
-    sequences = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
+    sequences = tokenizer_k.batch_decode(outputs.sequences, skip_special_tokens=False)
 
     results = []
     for i, sequence in enumerate(sequences):
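
The decode keeps special tokens (`skip_special_tokens=False`), so the cleanup loop that follows (unchanged, hence truncated in this diff) has to strip the echoed prompt and special tokens itself. A hedged sketch of that kind of post-processing, using a made-up sequence string:

```python
# Hedged sketch of cleaning one decoded Dolphin sequence; the real cleanup
# lives in the unchanged body of the loop above and may differ in detail.
prompt = "Parse the reading order of this document."
sequence = f"<s>{prompt} <Answer/> [layout elements...]</s><pad><pad>"

answer = (
    sequence.split("<Answer/>", 1)[-1]  # drop the echoed prompt
    .replace("</s>", "")                # drop EOS and padding artifacts
    .replace("<pad>", "")
    .strip()
)
print(answer)  # "[layout elements...]"
```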
@@ -157,6 +164,7 @@ def model_chat(prompt, image, is_batch=False):
 
     return results[0] if not is_batch else results
 
+@spaces.GPU
 def process_element_batch(elements, prompt, max_batch_size=16):
     """Process a batch of elements with the same prompt."""
     results = []
@@ -244,24 +252,41 @@ def generate_markdown(recognition_results):
         markdown += f"{element['text']}\n\n"
     return markdown.strip()
 
+def convert_to_image(image):
+    """Convert uploaded file to PIL Image, handling PDFs by extracting the first page."""
+    if isinstance(image, str):  # File path from Gradio
+        if image.lower().endswith('.pdf'):
+            doc = pymupdf.open(image)
+            page = doc[0]
+            pix = page.get_pixmap()
+            img_data = pix.tobytes("png")
+            pil_image = Image.open(io.BytesIO(img_data)).convert("RGB")
+            doc.close()
+            return pil_image
+        else:
+            return Image.open(image).convert("RGB")
+    elif isinstance(image, Image.Image):  # Already a PIL Image
+        return image.convert("RGB")
+    return None
+
 def process_image_with_dolphin(image):
     """Process a single image with Dolphin model."""
-    layout_output = model_chat("Parse the reading order of this document.", image)
-    elements = process_elements(layout_output, image)
+    pil_image = convert_to_image(image)
+    if pil_image is None:
+        return "Error: Unable to process the uploaded file."
+    layout_output = model_chat("Parse the reading order of this document.", pil_image)
+    elements = process_elements(layout_output, pil_image)
     markdown_content = generate_markdown(elements)
     return markdown_content
 
 @spaces.GPU
 def generate_image(model_name: str, text: str, image: Image.Image,
-                   max_new_tokens: int = 1024,
-                   temperature: float = 0.6,
-                   top_p: float = 0.9,
-                   top_k: int = 50,
-                   repetition_penalty: float = 1.2):
+                   max_new_tokens: int = 1024, temperature: float = 0.6,
+                   top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
     """Generate responses for image input using the selected model."""
     if model_name == "ByteDance-s-Dolphin":
         if image is None:
-            yield "Please upload an image."
+            yield "Please upload an image or PDF (first page will be processed)."
             return
         markdown_content = process_image_with_dolphin(image)
         yield markdown_content
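
The new `convert_to_image` helper rasterizes only page one of a PDF; note it relies on `io.BytesIO`, so the unchanged import block at the top of app.py presumably already brings in `io`. The PyMuPDF calls it uses can be exercised standalone:

```python
# Standalone exercise of the pymupdf calls used by convert_to_image:
# render the first page of a PDF into a PIL image. "sample.pdf" is a
# placeholder path; requires `pip install pymupdf pillow`.
import io

import pymupdf
from PIL import Image

doc = pymupdf.open("sample.pdf")
pix = doc[0].get_pixmap(dpi=150)  # dpi is optional; app.py uses the default 72
img = Image.open(io.BytesIO(pix.tobytes("png"))).convert("RGB")
doc.close()
print(img.size)
```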
@@ -280,7 +305,10 @@ def generate_image(model_name: str, text: str, image: Image.Image,
             yield "Please upload an image."
             return
 
-        images = [image]
+        images = [convert_to_image(image)]
+        if images[0] is None:
+            yield "Error: Unable to process the uploaded file."
+            return
 
         if model_name == "SmolDocling-256M-preview":
             if "OTSL" in text or "code" in text:
@@ -334,11 +362,8 @@
 
 @spaces.GPU
 def generate_video(model_name: str, text: str, video_path: str,
-                   max_new_tokens: int = 1024,
-                   temperature: float = 0.6,
-                   top_p: float = 0.9,
-                   top_k: int = 50,
-                   repetition_penalty: float = 1.2):
+                   max_new_tokens: int = 1024, temperature: float = 0.6,
+                   top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
     """Generate responses for video input using the selected model."""
     if model_name == "ByteDance-s-Dolphin":
         if video_path is None:
@@ -346,10 +371,10 @@
             return
         frames = downsample_video(video_path)
         markdown_contents = []
-        for frame, _ in frames:
+        for idx, (frame, _) in enumerate(frames):
             markdown_content = process_image_with_dolphin(frame)
-            markdown_contents.append(markdown_content)
-        combined_markdown = "\n\n".join(markdown_contents)
+            markdown_contents.append(f"**Frame {idx + 1}:**\n{markdown_content}")
+        combined_markdown = "\n\n---\n\n".join(markdown_contents)
         yield combined_markdown
     else:
         if model_name == "olmOCR-7B-0225-preview":
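
`downsample_video` itself is untouched by this commit, but the new loop unpacks `(frame, timestamp)` pairs and numbers the frames. A plausible cv2-based implementation consistent with that contract (a sketch, not the file's actual code, which this diff does not show):

```python
# Plausible sketch of a downsample_video consistent with the loop above:
# returns a list of (PIL.Image, timestamp-in-seconds) pairs, sampled evenly.
import cv2
import numpy as np
from PIL import Image

def downsample_video_sketch(video_path, num_frames=10):
    vidcap = cv2.VideoCapture(video_path)
    total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0
    frames = []
    for i in np.linspace(0, max(total - 1, 0), num_frames, dtype=int):
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ok, frame = vidcap.read()
        if ok:
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append((Image.fromarray(rgb), round(i / fps, 2)))
    vidcap.release()
    return frames
```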
@@ -419,7 +444,7 @@
     else:
         yield cleaned_output
 
-# Define examples for image and video inference
+# Define examples
 image_examples = [
     ["Convert this page to docling", "images/1.png"],
     ["OCR the image", "images/2.jpg"],
@@ -441,28 +466,23 @@ css = """
 }
 """
 
-# Create the Gradio Interface
+# Create Gradio Interface
 with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
     gr.Markdown("# **[Docling-VLMs](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
+    gr.Markdown("**Note:** For Dolphin model, the text query is ignored, and PDFs are processed by parsing the first page.")
     with gr.Row():
         with gr.Column():
             with gr.Tabs():
                 with gr.TabItem("Image Inference"):
                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                    image_upload = gr.Image(type="pil", label="Image")
+                    image_upload = gr.Image(type="pil", label="Image or PDF")
                     image_submit = gr.Button("Submit", elem_classes="submit-btn")
-                    gr.Examples(
-                        examples=image_examples,
-                        inputs=[image_query, image_upload]
-                    )
+                    gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
                 with gr.TabItem("Video Inference"):
                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                     video_upload = gr.Video(label="Video")
                     video_submit = gr.Button("Submit", elem_classes="submit-btn")
-                    gr.Examples(
-                        examples=video_examples,
-                        inputs=[video_query, video_upload]
-                    )
+                    gr.Examples(examples=video_examples, inputs=[video_query, video_upload])
             with gr.Accordion("Advanced options", open=False):
                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
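
The submit wiring is outside the changed hunks. Since `generate_image`/`generate_video` are generators, the handlers stream into an output component; a hedged reconstruction of that wiring (the `model_choice` and `output` component names are assumptions, not shown in this diff):

```python
# Hedged reconstruction of the (unchanged) event wiring; model_choice and
# output are assumed component names that this diff never shows.
image_submit.click(
    fn=generate_image,
    inputs=[model_choice, image_query, image_upload,
            max_new_tokens, temperature, top_p, top_k, repetition_penalty],
    outputs=output,
)
```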
 