prithivMLmods committed on
Commit 779b488 · verified · 1 Parent(s): f5d475f

Update app.py

Files changed (1):
  1. app.py +78 -61
app.py CHANGED
@@ -5,6 +5,7 @@ import json
 import time
 import asyncio
 from threading import Thread
+import tempfile
 
 import gradio as gr
 import spaces
@@ -47,16 +48,6 @@ model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-# Load typhoon-ocr-3b
-MODEL_ID_T = "scb10x/typhoon-ocr-3b"
-processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
-model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_T,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to(device).eval()
-
-
 # Load Gemma3n-E4B-it
 MODEL_ID_G = "google/gemma-3n-E4B-it"
 processor_g = AutoProcessor.from_pretrained(MODEL_ID_G, trust_remote_code=True)
@@ -66,29 +57,40 @@ model_g = AutoModelForImageTextToText.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
+# Load Gemma3n-E2B-it
+MODEL_ID_N = "google/gemma-3n-E2B-it"
+processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
+model_n = AutoModelForImageTextToText.from_pretrained(
+    MODEL_ID_N,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
 def downsample_video(video_path):
     """
-    Downsamples the video to evenly spaced frames.
-    Each frame is returned as a PIL image along with its timestamp.
+    Downsamples the video to evenly spaced frames and saves them to temporary files.
+    Returns a list of (frame_path, timestamp) and the temp directory.
     """
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
     fps = vidcap.get(cv2.CAP_PROP_FPS)
-    frames = []
     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    temp_dir = tempfile.mkdtemp()
+    frames = []
     for i in frame_indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
         success, image = vidcap.read()
         if success:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            pil_image = Image.fromarray(image)
+            frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
+            Image.fromarray(image).save(frame_path)
             timestamp = round(i / fps, 2)
-            frames.append((pil_image, timestamp))
+            frames.append((frame_path, timestamp))
     vidcap.release()
-    return frames
+    return frames, temp_dir
 
 @spaces.GPU
-def generate_image(model_name: str, text: str, image: Image.Image,
+def generate_image(model_name: str, text: str, image_path: str,
                    max_new_tokens: int = 1024,
                    temperature: float = 0.6,
                    top_p: float = 0.9,
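
The rewritten `downsample_video` hands back file paths plus the temp directory that owns them, so the caller decides when the frames are deleted. A minimal usage sketch under that contract — the input filename and the `shutil` cleanup are illustrative additions, not part of app.py:

```python
import shutil

# Assumes downsample_video as defined in the hunk above and a local test clip.
frames, temp_dir = downsample_video("sample.mp4")  # hypothetical input file

for frame_path, timestamp in frames:
    print(f"{timestamp:7.2f}s -> {frame_path}")    # 10 evenly spaced frames

# The diff does not show cleanup; the caller could remove the directory when done:
shutil.rmtree(temp_dir, ignore_errors=True)
```
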
@@ -103,30 +105,43 @@ def generate_image(model_name: str, text: str, image: Image.Image,
     elif model_name == "DREX-062225-7B-exp":
         processor = processor_x
         model = model_x
-    elif model_name == "Typhoon-OCR-3B":
-        processor = processor_t
-        model = model_t
     elif model_name == "Gemma3n-E4B-it":
         processor = processor_g
         model = model_g
+    elif model_name == "Gemma3n-E2B-it":
+        processor = processor_n
+        model = model_n
     else:
         yield "Invalid model selected.", "Invalid model selected."
         return
 
-    if image is None:
+    if image_path is None:
         yield "Please upload an image.", "Please upload an image."
         return
 
-    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text}]}]
-    inputs = processor.apply_chat_template(
-        messages,
-        tokenize=True,
-        add_generation_prompt=True,
-        return_dict=True,
-        return_tensors="pt",
-        truncation=False,
-        max_length=MAX_INPUT_TOKEN_LENGTH
-    ).to(device)
+    messages = [{"role": "user", "content": [{"type": "text", "text": text}, {"type": "image", "image": image_path}]}]
+
+    if model_name in ["Gemma3n-E4B-it", "Gemma3n-E2B-it"]:
+        inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+            truncation=False,
+            max_length=MAX_INPUT_TOKEN_LENGTH
+        ).to(device)
+    else:
+        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(
+            text=[prompt_full],
+            images=[image_path],
+            return_tensors="pt",
+            padding=True,
+            truncation=False,
+            max_length=MAX_INPUT_TOKEN_LENGTH
+        ).to(device)
+
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = {
         **inputs,
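
Input preparation now branches by model family: the Gemma3n processors tokenize the chat template straight into model inputs, while the Qwen2.5-VL-based models render the template to a prompt string first and then fuse it with the image in a plain processor call (the app passes the file path straight through as the image). A condensed sketch of the two paths — `build_inputs` is a hypothetical helper, and `processor`, `messages`, `image_path`, and `device` stand for the values prepared above:

```python
GEMMA_MODELS = {"Gemma3n-E4B-it", "Gemma3n-E2B-it"}  # names from the model radio

def build_inputs(model_name, processor, messages, image_path, device):
    if model_name in GEMMA_MODELS:
        # One call: chat template -> token ids and pixel values together.
        return processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device)
    # Two steps: render the prompt text, then combine it with the image.
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    return processor(
        text=[prompt],
        images=[image_path],
        return_tensors="pt",
        padding=True,
    ).to(device)
```
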
@@ -162,12 +177,12 @@ def generate_video(model_name: str, text: str, video_path: str,
     elif model_name == "DREX-062225-7B-exp":
         processor = processor_x
         model = model_x
-    elif model_name == "Typhoon-OCR-3B":
-        processor = processor_t
-        model = model_t
     elif model_name == "Gemma3n-E4B-it":
         processor = processor_g
         model = model_g
+    elif model_name == "Gemma3n-E2B-it":
+        processor = processor_n
+        model = model_n
     else:
         yield "Invalid model selected.", "Invalid model selected."
         return
@@ -176,26 +191,35 @@ def generate_video(model_name: str, text: str, video_path: str,
         yield "Please upload a video.", "Please upload a video."
         return
 
-    frames = downsample_video(video_path)
+    frames, temp_dir = downsample_video(video_path)
     content = [{"type": "text", "text": text}]
-    if model_name == "Gemma3n-E4B-it":
-        for frame, _ in frames:
-            content.append({"type": "image", "image": frame})
-    else:
-        for frame in frames:
-            image, timestamp = frame
-            content.append({"type": "text", "text": f"Frame {timestamp}:"})
-            content.append({"type": "image", "image": image})
+    for frame_path, timestamp in frames:
+        content.append({"type": "text", "text": f"Frame {timestamp}:"})
+        content.append({"type": "image", "image": frame_path})
     messages = [{"role": "user", "content": content}]
-    inputs = processor.apply_chat_template(
-        messages,
-        tokenize=True,
-        add_generation_prompt=True,
-        return_dict=True,
-        return_tensors="pt",
-        truncation=False,
-        max_length=MAX_INPUT_TOKEN_LENGTH
-    ).to(device)
+
+    if model_name in ["Gemma3n-E4B-it", "Gemma3n-E2B-it"]:
+        inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+            truncation=False,
+            max_length=MAX_INPUT_TOKEN_LENGTH
+        ).to(device)
+    else:
+        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        images = [frame_path for frame_path, _ in frames]
+        inputs = processor(
+            text=[prompt_full],
+            images=images,
+            return_tensors="pt",
+            padding=True,
+            truncation=False,
+            max_length=MAX_INPUT_TOKEN_LENGTH
+        ).to(device)
+
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = {
         **inputs,
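
Both generators end in the same streaming pattern: `model.generate` runs on a worker thread while the handler drains a `TextIteratorStreamer`, so the Gradio UI updates as tokens arrive. A sketch of that loop, assuming `model`, `processor`, and `inputs` are prepared as above (the processor forwards token decoding to its tokenizer here):

```python
from threading import Thread

from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": 1024}

# Run generation off-thread; the streamer is the only channel back to the UI.
Thread(target=model.generate, kwargs=generation_kwargs).start()

buffer = ""
for new_text in streamer:   # blocks until the next decoded chunk is ready
    buffer += new_text
    # app.py yields `buffer, buffer` here to refresh both output components
```
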
@@ -253,7 +277,7 @@ with gr.Blocks(css=css, theme=gr.themes.Citrus()) as demo:
     with gr.Tabs():
         with gr.TabItem("Image Inference"):
             image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-            image_upload = gr.Image(type="pil", label="Image")
+            image_upload = gr.Image(type="filepath", label="Image")
             image_submit = gr.Button("Submit", elem_classes="submit-btn")
             gr.Examples(
                 examples=image_examples,
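
With `type="filepath"`, Gradio hands the handler the upload's temp-file path (a `str`) instead of a decoded `PIL.Image`, which is what the new `generate_image(..., image_path: str, ...)` signature expects. A self-contained toy demo of the same contract (the handler and labels are hypothetical):

```python
import gradio as gr

def describe(image_path: str) -> str:
    # Gradio has already written the upload to disk; we receive only its path.
    return f"received file: {image_path}"

demo = gr.Interface(
    fn=describe,
    inputs=gr.Image(type="filepath", label="Image"),
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()
```
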
@@ -281,18 +305,11 @@ with gr.Blocks(css=css, theme=gr.themes.Citrus()) as demo:
     markdown_output = gr.Markdown(label="Formatted Result (Result.Md)")
 
     model_choice = gr.Radio(
-        choices=["DREX-062225-7B-exp", "VIREX-062225-7B-exp", "Typhoon-OCR-3B", "Gemma3n-E4B-it"],
+        choices=["DREX-062225-7B-exp", "VIREX-062225-7B-exp", "Gemma3n-E4B-it", "Gemma3n-E2B-it"],
         label="Select Model",
         value="DREX-062225-7B-exp"
     )
 
-    gr.Markdown("**Model Info 💻** | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Doc-VLMs/discussions)")
-    gr.Markdown("> [DREX-062225-7B-exp](https://huggingface.co/prithivMLmods/DREX-062225-exp): the drex-062225-exp (document retrieval and extraction expert) model is a specialized fine-tuned version of docscopeocr-7b-050425-exp, optimized for document retrieval, content extraction, and analysis recognition. built on top of the qwen2.5-vl architecture.")
-    gr.Markdown("> [VIREX-062225-7B-exp](https://huggingface.co/prithivMLmods/VIREX-062225-exp): the virex-062225-exp (video information retrieval and extraction expert - experimental) model is a fine-tuned version of qwen2.5-vl-7b-instruct, specifically optimized for advanced video understanding, image comprehension, sense of reasoning, and natural language decision-making through cot reasoning.")
-    gr.Markdown("> [Typhoon-OCR-3B](https://huggingface.co/scb10x/typhoon-ocr-3b): a bilingual document parsing model built specifically for real-world documents in thai and english, inspired by models like olmocr, based on qwen2.5-vl-instruction. this model is intended to be used with a specific prompt only.")
-    gr.Markdown("> [Gemma3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it): A multimodal model capable of processing images and videos for various tasks.")
-    gr.Markdown(">⚠️note: all the models in space are not guaranteed to perform well in video inference use cases.")
-
     image_submit.click(
         fn=generate_image,
         inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
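
The `.click` wiring selects the model at call time: the radio value arrives as the generator's first argument, and the streamed yields fan out to the two result components. A schematic of that wiring — the diff truncates the `outputs=` list, so the output component names here are assumptions:

```python
image_submit.click(
    fn=generate_image,
    inputs=[model_choice, image_query, image_upload,
            max_new_tokens, temperature, top_p, top_k, repetition_penalty],
    outputs=[output, markdown_output],  # assumed: raw text + formatted markdown
)
```
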