Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -62,7 +62,6 @@ model_x = AutoModelForVision2Seq.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-
 # Preprocessing functions for SmolDocling-256M
 def add_random_padding(image, min_percent=0.1, max_percent=0.10):
     """Add random padding to an image based on its size."""
@@ -110,18 +109,30 @@ def downsample_video(video_path):
     return frames
 
 # Dolphin-specific functions
-def model_chat(prompt, image):
-    """Use Dolphin model for inference."""
+def model_chat(prompt, image, is_batch=False):
+    """Use Dolphin model for inference, supporting both single and batch processing."""
     processor = processor_k
     model = model_k
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    inputs = processor(image, return_tensors="pt").to(device)
+
+    if not is_batch:
+        images = [image]
+        prompts = [prompt]
+    else:
+        images = image
+        prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
+
+    inputs = processor(images, return_tensors="pt", padding=True).to(device)
     pixel_values = inputs.pixel_values.half()
+
+    prompts = [f"<s>{p} <Answer/>" for p in prompts]
     prompt_inputs = processor.tokenizer(
-        f"<s>{prompt} <Answer/>",
+        prompts,
         add_special_tokens=False,
-        return_tensors="pt"
+        return_tensors="pt",
+        padding=True
     ).to(device)
+
     outputs = model.generate(
         pixel_values=pixel_values,
         decoder_input_ids=prompt_inputs.input_ids,
@@ -137,20 +148,48 @@ def model_chat(prompt, image):
         num_beams=1,
         repetition_penalty=1.1
     )
-    sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
-    cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
-    return cleaned
+    sequences = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
+
+    results = []
+    for i, sequence in enumerate(sequences):
+        cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
+        results.append(cleaned)
+
+    return results[0] if not is_batch else results
+
+def process_element_batch(elements, prompt, max_batch_size=16):
+    """Process a batch of elements with the same prompt."""
+    results = []
+    batch_size = min(len(elements), max_batch_size)
+
+    for i in range(0, len(elements), batch_size):
+        batch_elements = elements[i:i + batch_size]
+        crops_list = [elem["crop"] for elem in batch_elements]
+        prompts_list = [prompt] * len(crops_list)
+
+        batch_results = model_chat(prompts_list, crops_list, is_batch=True)
+
+        for j, result in enumerate(batch_results):
+            elem = batch_elements[j]
+            results.append({
+                "label": elem["label"],
+                "bbox": elem["bbox"],
+                "text": result.strip(),
+                "reading_order": elem["reading_order"],
+            })
+
+    return results
 
 def process_elements(layout_results, image):
     """Parse layout results and extract elements from the image."""
-    # Placeholder parsing logic based on expected Dolphin output
-    # Assuming layout_results is a string like "[(x1,y1,x2,y2,label), ...]"
     try:
         elements = ast.literal_eval(layout_results)
     except:
-        elements = []
+        elements = []
 
-    recognition_results = []
+    text_elements = []
+    table_elements = []
+    figure_results = []
     reading_order = 0
 
     for bbox, label in elements:
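The rewritten model_chat above normalizes single and batched inputs onto one code path, and the new process_element_batch chunks same-prompt crops into fixed-size batches. Below is a minimal, self-contained sketch of those two patterns, assuming a stand-in recognize step in place of the real Dolphin processor/model.generate call; the names chat and run_in_batches are illustrative, not from app.py.

# Sketch of the single-vs-batch normalization and fixed-size chunking
# introduced in this commit. The list comprehension stands in for inference.
def chat(prompt, image, is_batch=False):
    images = image if is_batch else [image]
    prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
    results = [f"{p} -> {img}" for p, img in zip(prompts, images)]  # stand-in inference
    return results if is_batch else results[0]

def run_in_batches(items, prompt, max_batch_size=16):
    # Chunk the work so no single call exceeds max_batch_size items.
    results = []
    for i in range(0, len(items), max_batch_size):
        chunk = items[i:i + max_batch_size]
        results.extend(chat([prompt] * len(chunk), chunk, is_batch=True))
    return results

print(chat("Read text in the image.", "crop0"))  # single call, unwrapped result
print(run_in_batches([f"crop{i}" for i in range(5)], "Read text in the image.", max_batch_size=2))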
@@ -158,27 +197,21 @@ def process_elements(layout_results, image):
             x1, y1, x2, y2 = map(int, bbox)
             cropped = image.crop((x1, y1, x2, y2))
             if cropped.size[0] > 0 and cropped.size[1] > 0:
+                element_info = {
+                    "crop": cropped,
+                    "label": label,
+                    "bbox": [x1, y1, x2, y2],
+                    "reading_order": reading_order,
+                }
                 if label == "text":
-                    text = model_chat("Read text in the image.", cropped)
-                    recognition_results.append({
-                        "label": label,
-                        "bbox": [x1, y1, x2, y2],
-                        "text": text.strip(),
-                        "reading_order": reading_order
-                    })
+                    text_elements.append(element_info)
                 elif label == "table":
-                    table_text = model_chat("Parse the table in the image.", cropped)
-                    recognition_results.append({
-                        "label": label,
-                        "bbox": [x1, y1, x2, y2],
-                        "text": table_text.strip(),
-                        "reading_order": reading_order
-                    })
+                    table_elements.append(element_info)
                 elif label == "figure":
-                    recognition_results.append({
+                    figure_results.append({
                         "label": label,
                         "bbox": [x1, y1, x2, y2],
-                        "text": "[Figure]",
+                        "text": "[Figure]",
                         "reading_order": reading_order
                     })
                 reading_order += 1
@@ -186,12 +219,23 @@ def process_elements(layout_results, image):
             print(f"Error processing element: {e}")
             continue
 
+    recognition_results = figure_results.copy()
+
+    if text_elements:
+        text_results = process_element_batch(text_elements, "Read text in the image.")
+        recognition_results.extend(text_results)
+
+    if table_elements:
+        table_results = process_element_batch(table_elements, "Parse the table in the image.")
+        recognition_results.extend(table_results)
+
+    recognition_results.sort(key=lambda x: x["reading_order"])
     return recognition_results
 
 def generate_markdown(recognition_results):
     """Generate markdown from extracted elements."""
     markdown = ""
-    for element in recognition_results:
+    for element in recognition_results:
         if element["label"] == "text":
             markdown += f"{element['text']}\n\n"
         elif element["label"] == "table":
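The net effect of the process_elements changes: per-element model calls are deferred, elements are first bucketed by label, the text and table buckets are recognized in batches, and the merged results are re-sorted by reading_order so document order survives the regrouping. A self-contained illustration of that regroup-then-restore-order flow follows; the dicts mirror element_info from the diff, and recognition itself is faked.

elements = [
    {"label": "text",   "reading_order": 0},
    {"label": "figure", "reading_order": 1},
    {"label": "table",  "reading_order": 2},
    {"label": "text",   "reading_order": 3},
]

# Bucket by label, as the new process_elements does.
text_elements  = [e for e in elements if e["label"] == "text"]
table_elements = [e for e in elements if e["label"] == "table"]
figure_results = [dict(e, text="[Figure]") for e in elements if e["label"] == "figure"]

# Batched recognition would fill in "text" for these two buckets; faked here.
text_results  = [dict(e, text="recognized text")  for e in text_elements]
table_results = [dict(e, text="recognized table") for e in table_elements]

merged = figure_results + text_results + table_results
merged.sort(key=lambda e: e["reading_order"])  # restore document order
print([e["label"] for e in merged])            # ['text', 'figure', 'table', 'text']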
@@ -222,7 +266,6 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         markdown_content = process_image_with_dolphin(image)
         yield markdown_content
     else:
-        # Existing logic for other models
         if model_name == "olmOCR-7B-0225-preview":
             processor = processor_m
             model = model_m
@@ -309,7 +352,6 @@ def generate_video(model_name: str, text: str, video_path: str,
         combined_markdown = "\n\n".join(markdown_contents)
         yield combined_markdown
     else:
-        # Existing logic for other models
        if model_name == "olmOCR-7B-0225-preview":
            processor = processor_m
            model = model_m
@@ -401,7 +443,7 @@ css = """
 
 # Create the Gradio Interface
 with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-    gr.Markdown("# **[
+    gr.Markdown("# **[Docling-VLMs](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
     with gr.Row():
         with gr.Column():
             with gr.Tabs():
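For context, here is a minimal runnable Gradio skeleton matching the layout this last hunk edits: a Markdown title at the top of a Blocks app, with the row/column/tabs nesting below. The css string and the bethecloud/storj_theme from app.py are omitted, and the tab name is illustrative rather than taken from the source.

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("# **[Docling-VLMs](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Image Inference"):  # hypothetical tab name
                    gr.Markdown("Model inputs and outputs go here.")

if __name__ == "__main__":
    demo.launch()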