prithivMLmods committed on
Commit a26c9d1 · verified · 1 Parent(s): bcbc345

Update app.py

Files changed (1)
  1. app.py +234 -122
app.py CHANGED
@@ -1,53 +1,79 @@
- import gradio as gr
- from transformers.image_utils import load_image
- from threading import Thread
  import time
- import torch
  import spaces
- import cv2
  import numpy as np
  from PIL import Image
  from transformers import (
      Qwen2VLForConditionalGeneration,
      AutoProcessor,
      TextIteratorStreamer,
  )
- from transformers import Qwen2_5_VLForConditionalGeneration

- # Helper Functions
- def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
-     """
-     Returns an HTML snippet for a thin animated progress bar with a label.
-     Colors can be customized; default colors are used for Qwen2VL/Aya-Vision.
-     """
-     return f'''
-     <div style="display: flex; align-items: center;">
-         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-         <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
-             <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
-         </div>
-     </div>
-     <style>
-     @keyframes loading {{
-         0% {{ transform: translateX(-100%); }}
-         100% {{ transform: translateX(100%); }}
-     }}
-     </style>
-     '''

  def downsample_video(video_path):
      """
-     Downsamples a video file by extracting 10 evenly spaced frames.
-     Returns a list of tuples (PIL.Image, timestamp).
      """
      vidcap = cv2.VideoCapture(video_path)
      total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
      fps = vidcap.get(cv2.CAP_PROP_FPS)
      frames = []
-     if total_frames <= 0 or fps <= 0:
-         vidcap.release()
-         return frames
-     frame_indices = np.linspace(0, total_frames - 1, 25, dtype=int)
      for i in frame_indices:
          vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
          success, image = vidcap.read()
@@ -59,116 +85,202 @@ def downsample_video(video_path):
      vidcap.release()
      return frames

- # Model and Processor Setup
- QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
- qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
- qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
-     QV_MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to("cuda").eval()
-
- ROLMOCR_MODEL_ID = "reducto/RolmOCR"
- rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
- rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     ROLMOCR_MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.bfloat16
- ).to("cuda").eval()
-
- # Main Inference Function
  @spaces.GPU
- @torch.no_grad()
- def model_inference(input_dict, history, use_rolmocr=False):
-     text = input_dict["text"].strip()
-     files = input_dict.get("files", [])
-
-     if not text and not files:
-         yield "Error: Please input a text query or provide files (images or videos)."
          return

-     # Process files: images and videos
-     image_list = []
-     for idx, file in enumerate(files):
-         if file.lower().endswith((".mp4", ".avi", ".mov")):
-             frames = downsample_video(file)
-             if not frames:
-                 yield "Error: Could not extract frames from the video."
-                 return
-             for frame, timestamp in frames:
-                 label = f"Video {idx+1} Frame {timestamp}:"
-                 image_list.append((label, frame))
-         else:
-             try:
-                 img = load_image(file)
-                 label = f"Image {idx+1}:"
-                 image_list.append((label, img))
-             except Exception as e:
-                 yield f"Error loading image: {str(e)}"
-                 return
-
-     # Build content list
-     content = [{"type": "text", "text": text}]
-     for label, img in image_list:
-         content.append({"type": "text", "text": label})
-         content.append({"type": "image", "image": img})
-
-     messages = [{"role": "user", "content": content}]
-
-     # Select processor and model
-     processor = rolmocr_processor if use_rolmocr else qwen_processor
-     model = rolmocr_model if use_rolmocr else qwen_model
-     model_name = "RolmOCR" if use_rolmocr else "Qwen2VL OCR"

      prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     all_images = [item["image"] for item in content if item["type"] == "image"]
      inputs = processor(
          text=[prompt_full],
-         images=all_images if all_images else None,
          return_tensors="pt",
          padding=True,
-     ).to("cuda")

      streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
      thread = Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()
      buffer = ""
-     yield progress_bar_html(f"Processing with {model_name}")
      for new_text in streamer:
          buffer += new_text
-         buffer = buffer.replace("<|im_end|>", "")
          time.sleep(0.01)
          yield buffer

- # Gradio Interface
- examples = [
-     [{"text": "OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
-     [{"text": "Explain the Ad in Detail", "files": ["examples/videoplayback.mp4"]}],
-     [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
  ]

- demo = gr.ChatInterface(
-     fn=model_inference,
-     description="# **[Multimodal OCR](https://huggingface.co/prithivMLmods/Qwen2-VL-OCR-2B-Instruct)**",
-     examples=examples,
-     textbox=gr.MultimodalTextbox(
-         label="Query Input",
-         file_types=["image", "video"],
-         file_count="multiple",
-         placeholder="Input your query and optionally upload image(s) or video(s). Select the model using the checkbox."
-     ),
-     stop_btn="Stop Generation",
-     multimodal=True,
-     cache_examples=False,
-     theme="bethecloud/storj_theme",
-     additional_inputs=[
-         gr.Checkbox(
-             label="Use RolmOCR",
-             value=False,
-             info="Check to use RolmOCR, uncheck to use Qwen2VL OCR"
          )
-     ],
- )

- demo.launch(share=True, mcp_server=True, debug=True, ssr_mode=False)

+ import os
+ import random
+ import uuid
+ import json
  import time
+ import asyncio
+ from threading import Thread
+
+ import gradio as gr
  import spaces
+ import torch
  import numpy as np
  from PIL import Image
+ import cv2
+
  from transformers import (
      Qwen2VLForConditionalGeneration,
+     Qwen2_5_VLForConditionalGeneration,
      AutoProcessor,
      TextIteratorStreamer,
  )
+ from transformers.image_utils import load_image

+ # Constants for text generation
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Load RolmOCR
+ MODEL_ID_M = "reducto/RolmOCR"
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
+ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_M,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load Qwen2-VL-OCR-2B-Instruct
+ MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
+ model_x = Qwen2VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_X,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load Lingshu-7B
+ MODEL_ID_Z = "lingshu-medical-mllm/Lingshu-7B"
+ processor_z = AutoProcessor.from_pretrained(MODEL_ID_Z, trust_remote_code=True)
+ model_z = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_Z,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load Nanonets-OCR-s
+ MODEL_ID_V = "nanonets/Nanonets-OCR-s"
+ processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
+ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_V,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
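All four checkpoints above are loaded with the same pattern: an AutoProcessor plus a Qwen2-VL or Qwen2.5-VL generation class in float16, moved to the detected device and set to eval mode. A minimal sketch of extending the app with a fifth entry, assuming a hypothetical Qwen2.5-VL-compatible checkpoint (the id below is a placeholder, not part of this commit):

    # Hypothetical additional checkpoint following the same loading pattern.
    MODEL_ID_N = "your-org/another-qwen2.5-vl-checkpoint"  # placeholder id
    processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
    model_n = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_N,
        trust_remote_code=True,
        torch_dtype=torch.float16
    ).to(device).eval()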
 
  def downsample_video(video_path):
      """
+     Downsamples the video to evenly spaced frames.
+     Each frame is returned as a PIL image along with its timestamp.
      """
      vidcap = cv2.VideoCapture(video_path)
      total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
      fps = vidcap.get(cv2.CAP_PROP_FPS)
      frames = []
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
      for i in frame_indices:
          vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
          success, image = vidcap.read()

      vidcap.release()
      return frames

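downsample_video returns a list of (PIL.Image, timestamp) tuples, at most 10 per clip in the new version. A quick standalone check, assuming a local clip at videos/1.mp4 (one of the example paths used further down):

    # Illustrative only: inspect the sampled frames of a local video file.
    frames = downsample_video("videos/1.mp4")
    for frame, timestamp in frames:
        print(timestamp, frame.size)  # each entry is a PIL image with its timestamp
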
  @spaces.GPU
+ def generate_image(model_name: str, text: str, image: Image.Image,
+                    max_new_tokens: int = 1024,
+                    temperature: float = 0.6,
+                    top_p: float = 0.9,
+                    top_k: int = 50,
+                    repetition_penalty: float = 1.2):
+     """
+     Generates responses using the selected model for image input.
+     """
+     if model_name == "RolmOCR":
+         processor = processor_m
+         model = model_m
+     elif model_name == "Qwen2-VL-OCR-2B-Instruct":
+         processor = processor_x
+         model = model_x
+     elif model_name == "Lingshu-7B":
+         processor = processor_z
+         model = model_z
+     elif model_name == "Nanonets-OCR-s":
+         processor = processor_v
+         model = model_v
+     else:
+         yield "Invalid model selected."
          return

+     if image is None:
+         yield "Please upload an image."
+         return
+
+     messages = [{
+         "role": "user",
+         "content": [
+             {"type": "image", "image": image},
+             {"type": "text", "text": text},
+         ]
+     }]
      prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
      inputs = processor(
          text=[prompt_full],
+         images=[image],
          return_tensors="pt",
          padding=True,
+         truncation=False,
+         max_length=MAX_INPUT_TOKEN_LENGTH
+     ).to(device)
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = {
+         **inputs,
+         "streamer": streamer,
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "repetition_penalty": repetition_penalty,
+     }
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         time.sleep(0.01)
+         yield buffer
+
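Both generate_image and generate_video are generator functions: every yield is the full text accumulated so far, which is what lets the Output textbox update incrementally while tokens stream in. A minimal sketch of calling generate_image directly, assuming a local test image at images/1.jpg (one of the example paths below):

    # Illustrative only: stream a response outside the Gradio UI.
    test_image = Image.open("images/1.jpg")
    final_text = ""
    for partial in generate_image("RolmOCR", "Perform OCR on the Image.", test_image):
        final_text = partial  # each yield replaces the previous partial result
    print(final_text)
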
+ @spaces.GPU
+ def generate_video(model_name: str, text: str, video_path: str,
+                    max_new_tokens: int = 1024,
+                    temperature: float = 0.6,
+                    top_p: float = 0.9,
+                    top_k: int = 50,
+                    repetition_penalty: float = 1.2):
+     """
+     Generates responses using the selected model for video input.
+     """
+     if model_name == "RolmOCR":
+         processor = processor_m
+         model = model_m
+     elif model_name == "Qwen2-VL-OCR-2B-Instruct":
+         processor = processor_x
+         model = model_x
+     elif model_name == "Lingshu-7B":
+         processor = processor_z
+         model = model_z
+     elif model_name == "Nanonets-OCR-s":
+         processor = processor_v
+         model = model_v
+     else:
+         yield "Invalid model selected."
+         return

+     if video_path is None:
+         yield "Please upload a video."
+         return
+
+     frames = downsample_video(video_path)
+     messages = [
+         {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+         {"role": "user", "content": [{"type": "text", "text": text}]}
+     ]
+     for frame in frames:
+         image, timestamp = frame
+         messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+         messages[1]["content"].append({"type": "image", "image": image})
+     inputs = processor.apply_chat_template(
+         messages,
+         tokenize=True,
+         add_generation_prompt=True,
+         return_dict=True,
+         return_tensors="pt",
+         truncation=False,
+         max_length=MAX_INPUT_TOKEN_LENGTH
+     ).to(device)
      streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = {
+         **inputs,
+         "streamer": streamer,
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "repetition_penalty": repetition_penalty,
+     }
      thread = Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()
      buffer = ""
      for new_text in streamer:
          buffer += new_text
          time.sleep(0.01)
          yield buffer

+ # Define examples for image and video inference
+ image_examples = [
+     ["Perform OCR on the Image.", "images/1.jpg"],
+     ["Extract the table content", "images/2.png"]
  ]

+ video_examples = [
+     ["Explain the watch ad in detail.", "videos/1.mp4"],
+     ["Identify the main actions in the cartoon video", "videos/2.mp4"]
+ ]
+
+ css = """
+ .submit-btn {
+     background-color: #2980b9 !important;
+     color: white !important;
+ }
+ .submit-btn:hover {
+     background-color: #3498db !important;
+ }
+ """
+
+ # Create the Gradio Interface
+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
+     gr.Markdown("# **[Multimodal OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
+     with gr.Row():
+         with gr.Column():
+             with gr.Tabs():
+                 with gr.TabItem("Image Inference"):
+                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                     image_upload = gr.Image(type="pil", label="Image")
+                     image_submit = gr.Button("Submit", elem_classes="submit-btn")
+                     gr.Examples(
+                         examples=image_examples,
+                         inputs=[image_query, image_upload]
+                     )
+                 with gr.TabItem("Video Inference"):
+                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                     video_upload = gr.Video(label="Video")
+                     video_submit = gr.Button("Submit", elem_classes="submit-btn")
+                     gr.Examples(
+                         examples=video_examples,
+                         inputs=[video_query, video_upload]
+                     )
+             with gr.Accordion("Advanced options", open=False):
+                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
+                 top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
+                 repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
+         with gr.Column():
+             output = gr.Textbox(label="Output", interactive=False, lines=2, scale=2)
+             model_choice = gr.Radio(
+                 choices=["Nanonets-OCR-s", "Qwen2-VL-OCR-2B-Instruct", "RolmOCR", "Lingshu-7B"],
+                 label="Select Model",
+                 value="RolmOCR"
              )
+
+             gr.Markdown("**Model Info**")
+             gr.Markdown("⤷ [Nanonets-OCR-s](https://huggingface.co/nanonets/Nanonets-OCR-s): Nanonets-OCR-s is a powerful, state-of-the-art image-to-markdown OCR model that goes far beyond traditional text extraction. It transforms documents into structured markdown with intelligent content recognition and semantic tagging.")
+             gr.Markdown("⤷ [Qwen2-VL-OCR-2B-Instruct](https://huggingface.co/prithivMLmods/Qwen2-VL-OCR-2B-Instruct): Qwen2-VL-OCR-2B-Instruct is a fine-tuned version of Qwen/Qwen2-VL-2B-Instruct, tailored for tasks that involve messy optical character recognition (OCR), image-to-text conversion, and math problem solving with LaTeX formatting.")
+             gr.Markdown("⤷ [RolmOCR](https://huggingface.co/reducto/RolmOCR): RolmOCR is a high-quality, openly available model for optical character recognition of PDFs and other complex documents. It is designed to handle a wide range of document types, including scanned documents, handwritten text, and complex layouts.")
+             gr.Markdown("⤷ [Lingshu-7B](https://huggingface.co/lingshu-medical-mllm/Lingshu-7B): Lingshu-7B is a generalist foundation model for unified multimodal medical understanding and reasoning, virtual assistants, and content generation.")
+
+     image_submit.click(
+         fn=generate_image,
+         inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+         outputs=output
+     )
+     video_submit.click(
+         fn=generate_video,
+         inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+         outputs=output
+     )

+ if __name__ == "__main__":
+     demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)
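The click handlers stream directly from the generator functions, so the Output textbox fills in progressively; queue(max_size=30) caps how many requests can wait in the queue at once. For programmatic access to a running instance, the exposed endpoints can be inspected with the Gradio client (the Space id below is a placeholder, not given in this commit):

    # Illustrative only: list the callable endpoints of a deployed instance.
    from gradio_client import Client
    client = Client("prithivMLmods/Multimodal-OCR")  # placeholder Space id
    client.view_api()  # prints the available endpoints and their parameters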