prithivMLmods committed
Commit 6401487 · verified · 1 Parent(s): 799c106

Update app.py

Files changed (1)
  1. app.py +94 -19
app.py CHANGED
@@ -1,10 +1,4 @@
 import gradio as gr
-from transformers import (
-    Qwen2VLForConditionalGeneration,
-    AutoProcessor,
-    TextIteratorStreamer,
-    AutoModelForImageTextToText,
-)
 from transformers.image_utils import load_image
 from threading import Thread
 import time
@@ -13,6 +7,14 @@ import spaces
 from PIL import Image
 import requests
 from io import BytesIO
+import cv2
+import numpy as np
+from transformers import (
+    Qwen2VLForConditionalGeneration,
+    AutoProcessor,
+    TextIteratorStreamer,
+    AutoModelForImageTextToText,
+)
 
 # Helper function to return a progress bar HTML snippet.
 def progress_bar_html(label: str) -> str:
@@ -20,7 +22,7 @@ def progress_bar_html(label: str) -> str:
     <div style="display: flex; align-items: center;">
         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
         <div style="width: 110px; height: 5px; background-color: #FFB6C1; border-radius: 2px; overflow: hidden;">
-            <div style="width: 100%; height: 100%; background-color: #FF69B4 ; animation: loading 1.5s linear infinite;"></div>
+            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
         </div>
     </div>
     <style>
@@ -31,7 +33,29 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
-QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # or use #prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
+# Helper function to downsample a video into 10 evenly spaced frames.
+def downsample_video(video_path):
+    vidcap = cv2.VideoCapture(video_path)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frames = []
+    # Sample 10 evenly spaced frames.
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    for i in frame_indices:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+    vidcap.release()
+    return frames
+
+# Model and processor setups
+
+# Setup for Qwen2VL OCR branch (default).
+QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"  # or use "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
 qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
 qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     QV_MODEL_ID,
@@ -39,25 +63,77 @@ qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
+# Setup for Aya-Vision branch.
 AYA_MODEL_ID = "CohereForAI/aya-vision-8b"
 aya_processor = AutoProcessor.from_pretrained(AYA_MODEL_ID)
 aya_model = AutoModelForImageTextToText.from_pretrained(
     AYA_MODEL_ID, device_map="auto", torch_dtype=torch.float16
 )
 
+# ---------------------------
+# Main Inference Function
+# ---------------------------
 @spaces.GPU
 def model_inference(input_dict, history):
     text = input_dict["text"].strip()
     files = input_dict.get("files", [])
 
+    # Branch for video inference with Aya-Vision using @video-infer.
+    if text.lower().startswith("@video-infer"):
+        prompt = text[len("@video-infer"):].strip()
+        if not files:
+            yield "Error: Please provide a video for the @video-infer feature."
+            return
+        video_path = files[0]
+        frames = downsample_video(video_path)
+        if not frames:
+            yield "Error: Could not extract frames from the video."
+            return
+        # Build messages: start with the prompt then add each frame with its timestamp.
+        content_list = []
+        content_list.append({"type": "text", "text": prompt})
+        for frame, timestamp in frames:
+            content_list.append({"type": "text", "text": f"Frame {timestamp}:"})
+            content_list.append({"type": "image", "image": frame})
+        messages = [{
+            "role": "user",
+            "content": content_list,
+        }]
+        inputs = aya_processor.apply_chat_template(
+            messages,
+            padding=True,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt"
+        ).to(aya_model.device)
+        streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = dict(
+            inputs,
+            streamer=streamer,
+            max_new_tokens=1024,
+            do_sample=True,
+            temperature=0.3
+        )
+        thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing video with Aya-Vision-8b")
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer
+        return
+
+    # Branch for single image inference with Aya-Vision using @aya-vision.
     if text.lower().startswith("@aya-vision"):
-        # Remove the command prefix and trim the prompt.
         text_prompt = text[len("@aya-vision"):].strip()
         if not files:
             yield "Error: Please provide an image for the @aya-vision feature."
             return
         else:
-            # For simplicity, use the first provided image.
+            # Use the first provided image.
             image = load_image(files[0])
             yield progress_bar_html("Processing with Aya-Vision-8b")
             messages = [{
@@ -75,7 +151,6 @@ def model_inference(input_dict, history):
                 return_dict=True,
                 return_tensors="pt"
             ).to(aya_model.device)
-            # Set up a streamer for Aya-Vision output
             streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
             generation_kwargs = dict(
                 inputs,
@@ -94,7 +169,7 @@ def model_inference(input_dict, history):
                 yield buffer
             return
 
-    # Load images if provided.
+    # Default branch: Use Qwen2VL OCR for text (with optional images).
     if len(files) > 1:
         images = [load_image(image) for image in files]
     elif len(files) == 1:
@@ -102,7 +177,6 @@ def model_inference(input_dict, history):
     else:
         images = []
 
-    # Validate input: require both text and (optionally) image(s).
     if text == "" and not images:
         yield "Error: Please input a query and optionally image(s)."
         return
@@ -110,7 +184,6 @@ def model_inference(input_dict, history):
         yield "Error: Please input a text query along with the image(s)."
         return
 
-    # Prepare messages for the Qwen2-VL model.
     messages = [{
         "role": "user",
         "content": [
@@ -129,11 +202,9 @@ def model_inference(input_dict, history):
         padding=True,
     ).to("cuda")
 
-    # Set up a streamer for real-time output.
    streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
 
-    # Start generation in a separate thread.
    thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
    thread.start()
 
@@ -145,7 +216,11 @@ def model_inference(input_dict, history):
         time.sleep(0.01)
         yield buffer
 
+
+# Gradio Interface Setup
+
 examples = [
+    [{"text": "@video-infer Summarize the video content", "files": ["examples/videoplayback.mp4"]}],
     [{"text": "@aya-vision Summarize the letter", "files": ["examples/1.png"]}],
     [{"text": "@aya-vision Extract JSON from the image", "files": ["example_images/document.jpg"]}],
     [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
@@ -160,13 +235,13 @@ examples = [
 
 demo = gr.ChatInterface(
     fn=model_inference,
-    description="# **Multimodal OCR `@aya-vision 'prompt..'`**",
+    description="# **Multimodal OCR and Video Inference with Aya-Vision (@aya-vision for image, @video-infer for video) and Qwen2VL OCR (default)**",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
-        file_types=["image"],
+        file_types=["image", "video"],
        file_count="multiple",
-        placeholder="By default, it runs Qwen2VL OCR, Tag @aya-vision for Aya Vision 8B"
+        placeholder="Tag @aya-vision for Aya-Vision image infer, @video-infer for Aya-Vision video infer, default runs Qwen2VL OCR"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
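
To sanity-check the new @video-infer path outside the Space, the sketch below exercises the frame-sampling helper and assembles the same prompt-plus-frames message list that the commit sends to Aya-Vision, without loading any model. It is a minimal sketch: the helper mirrors downsample_video from the diff above, while the video path "sample.mp4" and the final print are illustrative assumptions, not part of the commit.

```python
# Minimal sketch (not part of the commit): sample frames and build the
# @video-infer message structure. "sample.mp4" is a hypothetical local file.
import cv2
import numpy as np
from PIL import Image

def downsample_video(video_path):
    # Sample 10 evenly spaced frames, pairing each with its timestamp in seconds.
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            frames.append((Image.fromarray(image), round(int(i) / fps, 2)))
    vidcap.release()
    return frames

# Same content layout as the @video-infer branch: the user prompt first,
# then a "Frame <timestamp>:" text marker followed by each sampled frame.
frames = downsample_video("sample.mp4")
content_list = [{"type": "text", "text": "Summarize the video content"}]
for frame, timestamp in frames:
    content_list.append({"type": "text", "text": f"Frame {timestamp}:"})
    content_list.append({"type": "image", "image": frame})
messages = [{"role": "user", "content": content_list}]

print(f"Sampled {len(frames)} frames, message content entries: {len(content_list)}")
```

Interleaving a timestamp text entry with each image keeps the temporal order visible to the chat template, which appears to be why the branch builds one flat content list instead of attaching the frames separately.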