prithivMLmods committed
Commit 7019b95 · verified · Parent(s): 94d3a2b

Update app.py

Files changed (1)
  1. app.py +138 -62
app.py CHANGED
@@ -10,19 +10,24 @@ import gradio as gr
 import spaces
 import torch
 import numpy as np
-from PIL import Image
+from PIL import Image, ImageOps
 import cv2
 
 from transformers import (
     Qwen2VLForConditionalGeneration,
     Qwen2_5_VLForConditionalGeneration,
     AutoModelForVision2Seq,
-    AutoModelForImageTextToText,
     AutoProcessor,
     TextIteratorStreamer,
 )
 from transformers.image_utils import load_image
 
+from docling_core.types.doc import DoclingDocument, DocTagsDocument
+
+import re
+import ast
+import html
+
 # Constants for text generation
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -48,30 +53,51 @@ model_x = AutoModelForVision2Seq.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-#--------------------------------------------------------------------------------------#
-#Load MonkeyOCR
+# Load MonkeyOCR
 MODEL_ID_G = "echo840/MonkeyOCR"
 SUBFOLDER = "Recognition"
-
 processor_g = AutoProcessor.from_pretrained(
     MODEL_ID_G,
     trust_remote_code=True,
     subfolder=SUBFOLDER
 )
-
 model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     MODEL_ID_G,
     trust_remote_code=True,
     subfolder=SUBFOLDER,
     torch_dtype=torch.float16
 ).to(device).eval()
-#--------------------------------------------------------------------------------------#
+
+# Preprocessing functions for SmolDocling-256M
+def add_random_padding(image, min_percent=0.1, max_percent=0.10):
+    """Add random padding to an image based on its size."""
+    image = image.convert("RGB")
+    width, height = image.size
+    pad_w_percent = random.uniform(min_percent, max_percent)
+    pad_h_percent = random.uniform(min_percent, max_percent)
+    pad_w = int(width * pad_w_percent)
+    pad_h = int(height * pad_h_percent)
+    corner_pixel = image.getpixel((0, 0))  # Top-left corner
+    padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
+    return padded_image
+
+def normalize_values(text, target_max=500):
+    """Normalize numerical values in text to a target maximum."""
+    def normalize_list(values):
+        max_value = max(values) if values else 1
+        return [round((v / max_value) * target_max) for v in values]
+
+    def process_match(match):
+        num_list = ast.literal_eval(match.group(0))
+        normalized = normalize_list(num_list)
+        return "".join([f"<loc_{num}>" for num in normalized])
+
+    pattern = r"\[([\d\.\s,]+)\]"
+    normalized_text = re.sub(pattern, process_match, text)
+    return normalized_text
 
 def downsample_video(video_path):
-    """
-    Downsamples the video to evenly spaced frames.
-    Each frame is returned as a PIL image along with its timestamp.
-    """
+    """Downsample a video to evenly spaced frames, returning PIL images with timestamps."""
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
     fps = vidcap.get(cv2.CAP_PROP_FPS)
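
For reference, the new normalize_values helper rescales any bracketed number list in the prompt so its maximum maps to target_max, emitting <loc_N> tags. A minimal sketch of its behavior (the sample prompt string is hypothetical, not from the app):

    sample = "Identify element at [10, 20, 30]"  # hypothetical prompt
    print(normalize_values(sample, target_max=500))
    # Prints: Identify element at <loc_167><loc_333><loc_500>
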
@@ -95,18 +121,17 @@ def generate_image(model_name: str, text: str, image: Image.Image,
                    top_p: float = 0.9,
                    top_k: int = 50,
                    repetition_penalty: float = 1.2):
-    """
-    Generates responses using the selected model for image input.
-    """
+    """Generate responses for image input using the selected model."""
+    # Model selection
     if model_name == "Nanonets-OCR-s":
         processor = processor_m
         model = model_m
-    elif model_name == "SmolDocling-256M-preview":
-        processor = processor_x
-        model = model_x
     elif model_name == "MonkeyOCR-Recognition":
         processor = processor_g
         model = model_g
+    elif model_name == "SmolDocling-256M-preview":
+        processor = processor_x
+        model = model_x
     else:
         yield "Invalid model selected."
         return
@@ -115,33 +140,64 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         yield "Please upload an image."
         return
 
-    messages = [{
-        "role": "user",
-        "content": [
-            {"type": "image", "image": image},
-            {"type": "text", "text": text},
-        ]
-    }]
-    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor(
-        text=[prompt_full],
-        images=[image],
-        return_tensors="pt",
-        padding=True,
-        truncation=False,
-        max_length=MAX_INPUT_TOKEN_LENGTH
-    ).to(device)
+    # Prepare images as a list (single image for image inference)
+    images = [image]
+
+    # SmolDocling-256M specific preprocessing
+    if model_name == "SmolDocling-256M-preview":
+        if "OTSL" in text or "code" in text:
+            images = [add_random_padding(img) for img in images]
+        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
+            text = normalize_values(text, target_max=500)
+
+    # Unified message structure for all models
+    messages = [
+        {
+            "role": "user",
+            "content": [{"type": "image"} for _ in images] + [
+                {"type": "text", "text": text}
+            ]
+        }
+    ]
+    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+    # Generation with streaming
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+    generation_kwargs = {
+        **inputs,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+    }
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
+
+    # Stream output and collect full response
     buffer = ""
+    full_output = ""
     for new_text in streamer:
-        buffer += new_text
-        buffer = buffer.replace("<|im_end|>", "")
-        time.sleep(0.01)
+        full_output += new_text
+        buffer += new_text.replace("<|im_end|>", "")
         yield buffer
 
+    # SmolDocling-256M specific postprocessing
+    if model_name == "SmolDocling-256M-preview":
+        cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
+        if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
+            if "<chart>" in cleaned_output:
+                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
+            cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
+            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
+            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
+            markdown_output = doc.export_to_markdown()
+            yield f"**MD Output:**\n\n{markdown_output}"
+        else:
+            yield cleaned_output
+
 @spaces.GPU
 def generate_video(model_name: str, text: str, video_path: str,
                    max_new_tokens: int = 1024,
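
The generate_image changes above route SmolDocling's DocTags output through docling_core for Markdown export, using the same calls as the app. A minimal standalone sketch of that conversion path; the DocTags string and page image here are placeholders, not real model output:

    from PIL import Image
    from docling_core.types.doc import DoclingDocument, DocTagsDocument

    doctags = "<doctag><text><loc_58><loc_44><loc_426><loc_91>Hello world</text></doctag>"  # placeholder
    page = Image.new("RGB", (640, 480), "white")  # placeholder page image

    doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [page])
    doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
    print(doc.export_to_markdown())  # Markdown rendering of the parsed DocTags
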
@@ -149,18 +205,17 @@ def generate_video(model_name: str, text: str, video_path: str,
                    top_p: float = 0.9,
                    top_k: int = 50,
                    repetition_penalty: float = 1.2):
-    """
-    Generates responses using the selected model for video input.
-    """
+    """Generate responses for video input using the selected model."""
+    # Model selection
     if model_name == "Nanonets-OCR-s":
         processor = processor_m
         model = model_m
-    elif model_name == "SmolDocling-256M-preview":
-        processor = processor_x
-        model = model_x
     elif model_name == "MonkeyOCR-Recognition":
         processor = processor_g
         model = model_g
+    elif model_name == "SmolDocling-256M-preview":
+        processor = processor_x
+        model = model_x
     else:
         yield "Invalid model selected."
         return
@@ -169,30 +224,35 @@ def generate_video(model_name: str, text: str, video_path: str,
         yield "Please upload a video."
         return
 
+    # Extract frames from video
     frames = downsample_video(video_path)
+    images = [frame for frame, _ in frames]
+
+    # SmolDocling-256M specific preprocessing
+    if model_name == "SmolDocling-256M-preview":
+        if "OTSL" in text or "code" in text:
+            images = [add_random_padding(img) for img in images]
+        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
+            text = normalize_values(text, target_max=500)
+
+    # Unified message structure for all models
     messages = [
-        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-        {"role": "user", "content": [{"type": "text", "text": text}]}
+        {
+            "role": "user",
+            "content": [{"type": "image"} for _ in images] + [
+                {"type": "text", "text": text}
+            ]
+        }
     ]
-    for frame in frames:
-        image, timestamp = frame
-        messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
-        messages[1]["content"].append({"type": "image", "image": image})
-    inputs = processor.apply_chat_template(
-        messages,
-        tokenize=True,
-        add_generation_prompt=True,
-        return_dict=True,
-        return_tensors="pt",
-        truncation=False,
-        max_length=MAX_INPUT_TOKEN_LENGTH
-    ).to(device)
+    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+    # Generation with streaming
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = {
         **inputs,
         "streamer": streamer,
         "max_new_tokens": max_new_tokens,
-        "do_sample": True,
         "temperature": temperature,
         "top_p": top_p,
         "top_k": top_k,
@@ -200,13 +260,29 @@ def generate_video(model_name: str, text: str, video_path: str,
     }
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
+
+    # Stream output and collect full response
     buffer = ""
+    full_output = ""
     for new_text in streamer:
-        buffer += new_text
-        buffer = buffer.replace("<|im_end|>", "")
-        time.sleep(0.01)
+        full_output += new_text
+        buffer += new_text.replace("<|im_end|>", "")
         yield buffer
 
+    # SmolDocling-256M specific postprocessing
+    if model_name == "SmolDocling-256M-preview":
+        cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
+        if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
+            if "<chart>" in cleaned_output:
+                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
+            cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
+            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
+            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
+            markdown_output = doc.export_to_markdown()
+            yield f"**MD Output:**\n\n{markdown_output}"
+        else:
+            yield cleaned_output
+
 # Define examples for image and video inference
 image_examples = [
     ["fill the correct numbers", "example/image3.png"],