prithivMLmods committed on
Commit 7d955b3 · verified · 1 Parent(s): 5001e8d

Update app.py

Files changed (1)
  1. app.py +397 -301
app.py CHANGED
@@ -1,302 +1,398 @@
- import os
- import random
- import uuid
- import json
- import time
- import asyncio
- from threading import Thread
-
- import gradio as gr
- import spaces
- import torch
- import numpy as np
- from PIL import Image
- import cv2
-
- from transformers import (
-     Qwen2VLForConditionalGeneration,
-     Qwen2_5_VLForConditionalGeneration,
-     AutoModelForImageTextToText,
-     AutoProcessor,
-     TextIteratorStreamer,
- )
- from transformers.image_utils import load_image
-
- # Constants for text generation
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- # Load VIREX-062225-exp
- MODEL_ID_M = "prithivMLmods/VIREX-062225-exp"
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_M,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Load DREX-062225-exp
- MODEL_ID_X = "prithivMLmods/DREX-062225-exp"
- processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
- model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_X,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Load typhoon-ocr-3b
- MODEL_ID_T = "scb10x/typhoon-ocr-3b"
- processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
- model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_T,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Load olmOCR-7B-0225-preview
- MODEL_ID_O = "allenai/olmOCR-7B-0225-preview"
- processor_o = AutoProcessor.from_pretrained(MODEL_ID_O, trust_remote_code=True)
- model_o = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_O,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- def downsample_video(video_path):
-     """
-     Downsamples the video to evenly spaced frames.
-     Each frame is returned as a PIL image along with its timestamp.
-     """
-     vidcap = cv2.VideoCapture(video_path)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
-     frames = []
-     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
-     vidcap.release()
-     return frames
-
- @spaces.GPU
- def generate_image(model_name: str, text: str, image: Image.Image,
-                    max_new_tokens: int = 1024,
-                    temperature: float = 0.6,
-                    top_p: float = 0.9,
-                    top_k: int = 50,
-                    repetition_penalty: float = 1.2):
-     """
-     Generates responses using the selected model for image input.
-     """
-     if model_name == "VIREX-062225-7B-exp":
-         processor = processor_m
-         model = model_m
-     elif model_name == "DREX-062225-7B-exp":
-         processor = processor_x
-         model = model_x
-     elif model_name == "olmOCR-7B-0225-preview":
-         processor = processor_o
-         model = model_o
-     elif model_name == "Typhoon-OCR-3B":
-         processor = processor_t
-         model = model_t
-     else:
-         yield "Invalid model selected.", "Invalid model selected."
-         return
-
-     if image is None:
-         yield "Please upload an image.", "Please upload an image."
-         return
-
-     messages = [{
-         "role": "user",
-         "content": [
-             {"type": "image", "image": image},
-             {"type": "text", "text": text},
-         ]
-     }]
-     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     inputs = processor(
-         text=[prompt_full],
-         images=[image],
-         return_tensors="pt",
-         padding=True,
-         truncation=False,
-         max_length=MAX_INPUT_TOKEN_LENGTH
-     ).to(device)
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         time.sleep(0.01)
-         yield buffer, buffer
-
- @spaces.GPU
- def generate_video(model_name: str, text: str, video_path: str,
-                    max_new_tokens: int = 1024,
-                    temperature: float = 0.6,
-                    top_p: float = 0.9,
-                    top_k: int = 50,
-                    repetition_penalty: float = 1.2):
-     """
-     Generates responses using the selected model for video input.
-     """
-     if model_name == "VIREX-062225-7B-exp":
-         processor = processor_m
-         model = model_m
-     elif model_name == "DREX-062225-7B-exp":
-         processor = processor_x
-         model = model_x
-     elif model_name == "olmOCR-7B-0225-preview":
-         processor = processor_o
-         model = model_o
-     elif model_name == "Typhoon-OCR-3B":
-         processor = processor_t
-         model = model_t
-     else:
-         yield "Invalid model selected.", "Invalid model selected."
-         return
-
-     if video_path is None:
-         yield "Please upload a video.", "Please upload a video."
-         return
-
-     frames = downsample_video(video_path)
-     messages = [
-         {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-         {"role": "user", "content": [{"type": "text", "text": text}]}
-     ]
-     for frame in frames:
-         image, timestamp = frame
-         messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
-         messages[1]["content"].append({"type": "image", "image": image})
-     inputs = processor.apply_chat_template(
-         messages,
-         tokenize=True,
-         add_generation_prompt=True,
-         return_dict=True,
-         return_tensors="pt",
-         truncation=False,
-         max_length=MAX_INPUT_TOKEN_LENGTH
-     ).to(device)
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = {
-         **inputs,
-         "streamer": streamer,
-         "max_new_tokens": max_new_tokens,
-         "do_sample": True,
-         "temperature": temperature,
-         "top_p": top_p,
-         "top_k": top_k,
-         "repetition_penalty": repetition_penalty,
-     }
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         buffer = buffer.replace("<|im_end|>", "")
-         time.sleep(0.01)
-         yield buffer, buffer
-
- # Define examples for image and video inference
- image_examples = [
-     ["Convert this page to doc [text] precisely.", "images/3.png"],
-     ["Convert this page to doc [text] precisely.", "images/4.png"],
-     ["Convert this page to doc [text] precisely.", "images/1.png"],
-     ["Convert chart to OTSL.", "images/2.png"]
- ]
-
- video_examples = [
-     ["Explain the video in detail.", "videos/2.mp4"],
-     ["Explain the ad in detail.", "videos/1.mp4"]
- ]
-
- # Added CSS to style the output area as a "Canvas"
- css = """
- .submit-btn {
-     background-color: #2980b9 !important;
-     color: white !important;
- }
- .submit-btn:hover {
-     background-color: #3498db !important;
- }
- .canvas-output {
-     border: 2px solid #4682B4;
-     border-radius: 10px;
-     padding: 20px;
- }
- """
-
- # Create the Gradio Interface
- with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-     gr.Markdown("# **[Doc VLMs OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
-     with gr.Row():
-         with gr.Column():
-             with gr.Tabs():
-                 with gr.TabItem("Image Inference"):
-                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                     image_upload = gr.Image(type="pil", label="Image")
-                     image_submit = gr.Button("Submit", elem_classes="submit-btn")
-                     gr.Examples(
-                         examples=image_examples,
-                         inputs=[image_query, image_upload]
-                     )
-                 with gr.TabItem("Video Inference"):
-                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                     video_upload = gr.Video(label="Video")
-                     video_submit = gr.Button("Submit", elem_classes="submit-btn")
-                     gr.Examples(
-                         examples=video_examples,
-                         inputs=[video_query, video_upload]
-                     )
-             with gr.Accordion("Advanced options", open=False):
-                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
-                 top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
-                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-                 repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
-
-         with gr.Column():
-             with gr.Column(elem_classes="canvas-output"):
-                 gr.Markdown("## Result Canvas")
-                 output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2)
-                 markdown_output = gr.Markdown(label="Formatted Result (Result.Md)")
-
-             model_choice = gr.Radio(
-                 choices=["DREX-062225-7B-exp", "olmOCR-7B-0225-preview", "VIREX-062225-7B-exp", "Typhoon-OCR-3B"],
-                 label="Select Model",
-                 value="DREX-062225-7B-exp"
-             )
-
-             gr.Markdown("**Model Info 💻** | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Doc-VLMs/discussions)")
-             gr.Markdown("> [DREX-062225-7B-exp](https://huggingface.co/prithivMLmods/DREX-062225-exp): the drex-062225-exp (document retrieval and extraction expert) model is a specialized fine-tuned version of docscopeocr-7b-050425-exp, optimized for document retrieval, content extraction, and analysis recognition. built on top of the qwen2.5-vl architecture.")
-             gr.Markdown("> [VIREX-062225-7B-exp](https://huggingface.co/prithivMLmods/VIREX-062225-exp): the virex-062225-exp (video information retrieval and extraction expert - experimental) model is a fine-tuned version of qwen2.5-vl-7b-instruct, specifically optimized for advanced video understanding, image comprehension, sense of reasoning, and natural language decision-making through cot reasoning.")
-             gr.Markdown("> [Typhoon-OCR-3B](https://huggingface.co/scb10x/typhoon-ocr-3b): a bilingual document parsing model built specifically for real-world documents in thai and english, inspired by models like olmocr, based on qwen2.5-vl-instruction. this model is intended to be used with a specific prompt only.")
-             gr.Markdown("> [olmOCR-7B-0225](https://huggingface.co/allenai/olmOCR-7B-0225-preview): the olmocr-7b-0225-preview model is based on qwen2-vl-7b, optimized for document-level optical character recognition (ocr), long-context vision-language understanding, and accurate image-to-text conversion with mathematical latex formatting. designed with a focus on high-fidelity visual-textual comprehension.")
-             gr.Markdown(">⚠️note: all the models in space are not guaranteed to perform well in video inference use cases.")
-
-     image_submit.click(
-         fn=generate_image,
-         inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-         outputs=[output, markdown_output]
-     )
-     video_submit.click(
-         fn=generate_video,
-         inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-         outputs=[output, markdown_output]
-     )
-
- if __name__ == "__main__":
+ import os
+ import random
+ import uuid
+ import json
+ import time
+ import asyncio
+ from threading import Thread
+
+ import gradio as gr
+ import spaces
+ import torch
+ import numpy as np
+ from PIL import Image, ImageDraw
+ import cv2
+ import re
+
+ from transformers import (
+     Qwen2_5_VLForConditionalGeneration,
+     AutoProcessor,
+     TextIteratorStreamer,
+ )
+ from transformers.image_utils import load_image
+
+ # Constants for text generation
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Load Camel-Doc-OCR-062825
+ MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
+ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_M,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load Qwen2.5-VL-7B-Instruct
+ MODEL_ID_X = "Qwen/Qwen2.5-VL-7B-Instruct"
+ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
+ model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_X,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load Qwen2.5-VL-3B-Instruct
+ MODEL_ID_T = "Qwen/Qwen2.5-VL-3B-Instruct"
+ processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
+ model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_T,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ def downsample_video(video_path):
+     """
+     Downsamples the video to evenly spaced frames.
+     Each frame is returned as a PIL image along with its timestamp.
+     """
+     vidcap = cv2.VideoCapture(video_path)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frames = []
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+     vidcap.release()
+     return frames
+
+ def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=2):
+     draw = ImageDraw.Draw(image)
+     for box in bounding_boxes:
+         xmin, ymin, xmax, ymax = box
+         draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width)
+     return image
+
+ def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scaled_width=1000, scaled_height=1000):
+     x_scale = original_width / scaled_width
+     y_scale = original_height / scaled_height
+     rescaled_boxes = []
+     for box in bounding_boxes:
+         xmin, ymin, xmax, ymax = box
+         rescaled_box = [
+             xmin * x_scale,
+             ymin * y_scale,
+             xmax * x_scale,
+             ymax * y_scale
+         ]
+         rescaled_boxes.append(rescaled_box)
+     return rescaled_boxes
+
+ @spaces.GPU
+ def generate_image(model_name: str, text: str, image: Image.Image,
+                    max_new_tokens: int = 1024,
+                    temperature: float = 0.6,
+                    top_p: float = 0.9,
+                    top_k: int = 50,
+                    repetition_penalty: float = 1.2):
+     """
+     Generates responses using the selected model for image input.
+     """
+     if model_name == "Camel-Doc-OCR-062825":
+         processor = processor_m
+         model = model_m
+     elif model_name == "Qwen2.5-VL-7B-Instruct":
+         processor = processor_x
+         model = model_x
+     elif model_name == "Qwen2.5-VL-3B-Instruct":
+         processor = processor_t
+         model = model_t
+     else:
+         yield "Invalid model selected.", "Invalid model selected."
+         return
+
+     if image is None:
+         yield "Please upload an image.", "Please upload an image."
+         return
+
+     messages = [{
+         "role": "user",
+         "content": [
+             {"type": "image", "image": image},
+             {"type": "text", "text": text},
+         ]
+     }]
+     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = processor(
+         text=[prompt_full],
+         images=[image],
+         return_tensors="pt",
+         padding=True,
+         truncation=False,
+         max_length=MAX_INPUT_TOKEN_LENGTH
+     ).to(device)
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         time.sleep(0.01)
+         yield buffer, buffer
+
+ @spaces.GPU
+ def generate_video(model_name: str, text: str, video_path: str,
+                    max_new_tokens: int = 1024,
+                    temperature: float = 0.6,
+                    top_p: float = 0.9,
+                    top_k: int = 50,
+                    repetition_penalty: float = 1.2):
+     """
+     Generates responses using the selected model for video input.
+     """
+     if model_name == "Camel-Doc-OCR-062825":
+         processor = processor_m
+         model = model_m
+     elif model_name == "Qwen2.5-VL-7B-Instruct":
+         processor = processor_x
+         model = model_x
+     elif model_name == "Qwen2.5-VL-3B-Instruct":
+         processor = processor_t
+         model = model_t
+     else:
+         yield "Invalid model selected.", "Invalid model selected."
+         return
+
+     if video_path is None:
+         yield "Please upload a video.", "Please upload a video."
+         return
+
+     frames = downsample_video(video_path)
+     messages = [
+         {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+         {"role": "user", "content": [{"type": "text", "text": text}]}
+     ]
+     for frame in frames:
+         image, timestamp = frame
+         messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+         messages[1]["content"].append({"type": "image", "image": image})
+     inputs = processor.apply_chat_template(
+         messages,
+         tokenize=True,
+         add_generation_prompt=True,
+         return_dict=True,
+         return_tensors="pt",
+         truncation=False,
+         max_length=MAX_INPUT_TOKEN_LENGTH
+     ).to(device)
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = {
+         **inputs,
+         "streamer": streamer,
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "repetition_penalty": repetition_penalty,
+     }
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         buffer = buffer.replace("<|im_end|>", "")
+         time.sleep(0.01)
+         yield buffer, buffer
+
+ @spaces.GPU
+ def run_object_detection(model_name: str, image: Image.Image, text_input: str, system_prompt: str,
+                          max_new_tokens: int = 1024,
+                          temperature: float = 0.6,
+                          top_p: float = 0.9,
+                          top_k: int = 50,
+                          repetition_penalty: float = 1.2):
+     if model_name == "Camel-Doc-OCR-062825":
+         processor = processor_m
+         model = model_m
+     elif model_name == "Qwen2.5-VL-7B-Instruct":
+         processor = processor_x
+         model = model_x
+     elif model_name == "Qwen2.5-VL-3B-Instruct":
+         processor = processor_t
+         model = model_t
+     else:
+         return "Invalid model selected.", "", image
+
+     if image is None:
+         return "Please upload an image.", "", image
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "text", "text": system_prompt},
+                 {"type": "text", "text": text_input},
+                 {"type": "image", "image": image},
+             ],
+         }
+     ]
+     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = processor(
+         text=[prompt_full],
+         images=[image],
+         return_tensors="pt",
+         padding=True,
+         truncation=False,
+         max_length=MAX_INPUT_TOKEN_LENGTH
+     ).to(device)
+     generation_kwargs = {
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "repetition_penalty": repetition_penalty,
+     }
+     generated_ids = model.generate(**inputs, **generation_kwargs)
+     generated_ids_trimmed = generated_ids[:, inputs["input_ids"].shape[1]:]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )[0]
+     pattern = r'\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]'
+     matches = re.findall(pattern, output_text)
+     parsed_boxes = [[int(num) for num in match] for match in matches]
+     original_width, original_height = image.size
+     scaled_boxes = rescale_bounding_boxes(parsed_boxes, original_width, original_height)
+     annotated_image = draw_bounding_boxes(image.copy(), scaled_boxes)
+     return output_text, str(parsed_boxes), annotated_image
+
+ # Define examples for image and video inference
+ image_examples = [
+     ["Convert this page to doc [text] precisely.", "images/3.png"],
+     ["Convert this page to doc [text] precisely.", "images/4.png"],
+     ["Convert this page to doc [text] precisely.", "images/1.png"],
+     ["Convert chart to OTSL.", "images/2.png"]
+ ]
+
+ video_examples = [
+     ["Explain the video in detail.", "videos/2.mp4"],
+     ["Explain the ad in detail.", "videos/1.mp4"]
+ ]
+
+ # Define examples for object detection
+ default_system_prompt = "You are a helpful assistant to detect objects in images. When asked to detect elements based on a description you return bounding boxes for all elements in the form of [xmin, ymin, xmax, ymax] with the values being scaled to 1000 by 1000 pixels. When there are more than one result, answer with a list of bounding boxes in the form of [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]."
+ object_detection_examples = [
+     ["images/3.png", "Detect all text blocks", default_system_prompt],
+     ["images/4.png", "Find all images", default_system_prompt],
+     ["images/1.png", "Locate the headers", default_system_prompt],
+     ["images/2.png", "Detect the chart", default_system_prompt],
+ ]
+
+ # Added CSS to style the output area as a "Canvas"
+ css = """
+ .submit-btn {
+     background-color: #2980b9 !important;
+     color: white !important;
+ }
+ .submit-btn:hover {
+     background-color: #3498db !important;
+ }
+ .canvas-output {
+     border: 2px solid #4682B4;
+     border-radius: 10px;
+     padding: 20px;
+ }
+ """
+
+ # Create the Gradio Interface
+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
+     gr.Markdown("# **[Doc-VLMs-v2-Localization](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
+     with gr.Row():
+         with gr.Column():
+             model_choice = gr.Radio(
+                 choices=["Camel-Doc-OCR-062825", "Qwen2.5-VL-7B-Instruct", "Qwen2.5-VL-3B-Instruct"],
+                 label="Select Model",
+                 value="Camel-Doc-OCR-062825"
+             )
+             with gr.Tabs():
+                 with gr.TabItem("Image Inference"):
+                     with gr.Row():
+                         with gr.Column():
+                             image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                             image_upload = gr.Image(type="pil", label="Image")
+                             image_submit = gr.Button("Submit", elem_classes="submit-btn")
+                             gr.Examples(
+                                 examples=image_examples,
+                                 inputs=[image_query, image_upload]
+                             )
+                         with gr.Column():
+                             output_image = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2)
+                             markdown_output_image = gr.Markdown(label="Formatted Result (Result.Md)")
+
+                 with gr.TabItem("Video Inference"):
+                     with gr.Row():
+                         with gr.Column():
+                             video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                             video_upload = gr.Video(label="Video")
+                             video_submit = gr.Button("Submit", elem_classes="submit-btn")
+                             gr.Examples(
+                                 examples=video_examples,
+                                 inputs=[video_query, video_upload]
+                             )
+                         with gr.Column():
+                             output_video = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2)
+                             markdown_output_video = gr.Markdown(label="Formatted Result (Result.Md)")
+
+                 with gr.TabItem("Object Detection"):
+                     with gr.Row():
+                         with gr.Column():
+                             input_img = gr.Image(label="Input Image", type="pil")
+                             system_prompt = gr.Textbox(label="System Prompt", value=default_system_prompt)
+                             text_input = gr.Textbox(label="User Prompt")
+                             object_detection_submit = gr.Button("Submit", elem_classes="submit-btn")
+                             gr.Examples(
+                                 examples=object_detection_examples,
+                                 inputs=[input_img, text_input, system_prompt]
+                             )
+                         with gr.Column():
+                             model_output_text = gr.Textbox(label="Model Output Text")
+                             parsed_boxes = gr.Textbox(label="Parsed Boxes")
+                             annotated_image = gr.Image(label="Annotated Image")
+
+             with gr.Accordion("Advanced options", open=False):
+                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
+                 top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
+                 repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
+
+     image_submit.click(
+         fn=generate_image,
+         inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+         outputs=[output_image, markdown_output_image]
+     )
+     video_submit.click(
+         fn=generate_video,
+         inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+         outputs=[output_video, markdown_output_video]
+     )
+     object_detection_submit.click(
+         fn=run_object_detection,
+         inputs=[model_choice, input_img, text_input, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+         outputs=[model_output_text, parsed_boxes, annotated_image]
+     )
+
+ if __name__ == "__main__":
      demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)
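
The object-detection tab added in this commit rests on a simple coordinate convention: the default system prompt asks the model for [xmin, ymin, xmax, ymax] boxes in a fixed 1000 by 1000 space, and run_object_detection parses them with a regex, maps them back to the real image size with rescale_bounding_boxes, and draws them with draw_bounding_boxes. The following is a minimal standalone sketch of that post-processing, separate from app.py; the model reply and page size are invented purely for illustration.

import re
from PIL import Image, ImageDraw

# Hypothetical model reply and page size, for illustration only.
output_text = "[[120, 80, 480, 300], [510, 640, 940, 890]]"
original_width, original_height = 1654, 2339

# Same regex as run_object_detection: capture every [xmin, ymin, xmax, ymax] quadruple.
pattern = r'\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]'
boxes_1000 = [[int(n) for n in match] for match in re.findall(pattern, output_text)]

# Rescale from the 1000x1000 prompt space to pixel coordinates, as rescale_bounding_boxes does.
x_scale = original_width / 1000
y_scale = original_height / 1000
boxes_px = [[xmin * x_scale, ymin * y_scale, xmax * x_scale, ymax * y_scale]
            for xmin, ymin, xmax, ymax in boxes_1000]

# Draw the rescaled boxes on a blank canvas standing in for the uploaded image.
canvas = Image.new("RGB", (original_width, original_height), "white")
draw = ImageDraw.Draw(canvas)
for box in boxes_px:
    draw.rectangle(box, outline="red", width=2)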