prithivMLmods committed · Commit c7906eb · verified · 1 Parent(s): 1366989

Update app.py

Files changed (1):
  1. app.py  +192 -315
app.py CHANGED
@@ -1,343 +1,220 @@
-#!/usr/bin/env python

-import os
-import re
-import tempfile
-from collections.abc import Iterator
-from threading import Thread

 import cv2
-import gradio as gr
-import spaces
 import torch
-from loguru import logger
 from PIL import Image
-from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
-
-model_id = os.getenv("MODEL_ID", "google/gemma-3-12b-it")
-processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
-model = Gemma3ForConditionalGeneration.from_pretrained(
-    model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
-)
-
-MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
-
-
-def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
-    image_count = 0
-    video_count = 0
-    for path in paths:
-        if path.endswith(".mp4"):
-            video_count += 1
-        else:
-            image_count += 1
-    return image_count, video_count

-def count_files_in_history(history: list[dict]) -> tuple[int, int]:
-    image_count = 0
-    video_count = 0
-    for item in history:
-        if item["role"] != "user" or isinstance(item["content"], str):
-            continue
-        if item["content"][0].endswith(".mp4"):
-            video_count += 1
-        else:
-            image_count += 1
-    return image_count, video_count

-def validate_media_constraints(message: dict, history: list[dict]) -> bool:
-    new_image_count, new_video_count = count_files_in_new_message(message["files"])
-    history_image_count, history_video_count = count_files_in_history(history)
-    image_count = history_image_count + new_image_count
-    video_count = history_video_count + new_video_count
-    if video_count > 1:
-        gr.Warning("Only one video is supported.")
-        return False
-    if video_count == 1:
-        if image_count > 0:
-            gr.Warning("Mixing images and videos is not allowed.")
-            return False
-        if "<image>" in message["text"]:
-            gr.Warning("Using <image> tags with video files is not supported.")
-            return False
-    if video_count == 0 and image_count > MAX_NUM_IMAGES:
-        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
-        return False
-    if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
-        gr.Warning("The number of <image> tags in the text does not match the number of images.")
-        return False
-    return True

-def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
     vidcap = cv2.VideoCapture(video_path)
-    fps = vidcap.get(cv2.CAP_PROP_FPS)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-    max_frames = 5  # Limit to 5 frames to prevent memory overload
-    if total_frames <= max_frames:
-        indices = list(range(total_frames))
-    else:
-        indices = [int(i * (total_frames - 1) / (max_frames - 1)) for i in range(max_frames)]
-
     frames = []
-    for i in indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-        success, image = vidcap.read()
         if success:
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            pil_image = Image.fromarray(image)
             timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
-
     vidcap.release()
     return frames

-def process_video(video_path: str) -> list[dict]:
-    content = []
     frames = downsample_video(video_path)
-    for frame in frames:
-        pil_image, timestamp = frame
-        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
-            pil_image.save(temp_file.name)
-            content.append({"type": "text", "text": f"Frame {timestamp}:"})
-            content.append({"type": "image", "url": temp_file.name})
-    logger.debug(f"{content=}")
-    return content
-
-def process_interleaved_images(message: dict) -> list[dict]:
-    logger.debug(f"{message['files']=}")
-    parts = re.split(r"(<image>)", message["text"])
-    logger.debug(f"{parts=}")
-
-    content = []
-    image_index = 0
-    for part in parts:
-        logger.debug(f"{part=}")
-        if part == "<image>":
-            content.append({"type": "image", "url": message["files"][image_index]})
-            logger.debug(f"file: {message['files'][image_index]}")
-            image_index += 1
-        elif part.strip():
-            content.append({"type": "text", "text": part.strip()})
-        elif isinstance(part, str) and part != "<image>":
-            content.append({"type": "text", "text": part})
-    logger.debug(f"{content=}")
-    return content
-
-def process_new_user_message(message: dict) -> list[dict]:
-    if not message["files"]:
-        return [{"type": "text", "text": message["text"]}]
-
-    if message["files"][0].endswith(".mp4"):
-        return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
-
-    if "<image>" in message["text"]:
-        return process_interleaved_images(message)
-
-    return [
-        {"type": "text", "text": message["text"]},
-        *[{"type": "image", "url": path} for path in message["files"]],
-    ]
-
-def process_history(history: list[dict]) -> list[dict]:
-    messages = []
-    current_user_content: list[dict] = []
-    for item in history:
-        if item["role"] == "assistant":
-            if current_user_content:
-                messages.append({"role": "user", "content": current_user_content})
-                current_user_content = []
-            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
         else:
-            content = item["content"]
-            if isinstance(content, str):
-                current_user_content.append({"type": "text", "text": content})
-            else:
-                current_user_content.append({"type": "image", "url": content[0]})
-    return messages
-
-@spaces.GPU(duration=90)
-def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
-    if not validate_media_constraints(message, history):
-        yield ""
-        return
-
-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
-    messages.extend(process_history(history))
-    messages.append({"role": "user", "content": process_new_user_message(message)})
-
-    inputs = processor.apply_chat_template(
-        messages,
-        add_generation_prompt=True,
-        tokenize=True,
-        return_dict=True,
-        return_tensors="pt",
-    ).to(device=model.device, dtype=torch.bfloat16)
-
-    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        inputs,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    output = ""
-    for delta in streamer:
-        output += delta
-        yield output
-
-examples = [
-    [
-        {
-            "text": "I need to be in Japan for 10 days, going to Tokyo, Kyoto and Osaka. Think about number of attractions in each of them and allocate number of days to each city. Make public transport recommendations.",
-            "files": [],
-        }
-    ],
-    [
-        {
-            "text": "Write the matplotlib code to generate the same bar chart.",
-            "files": ["assets/additional-examples/barchart.png"],
-        }
-    ],
-    [
-        {
-            "text": "What is odd about this video?",
-            "files": ["assets/additional-examples/tmp.mp4"],
-        }
-    ],
-    [
-        {
-            "text": "I already have this supplement <image> and I want to buy this one <image>. Any warnings I should know about?",
-            "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"],
-        }
-    ],
-    [
-        {
-            "text": "Write a poem inspired by the visual elements of the images.",
-            "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"],
-        }
-    ],
-    [
-        {
-            "text": "Compose a short musical piece inspired by the visual elements of the images.",
-            "files": [
-                "assets/sample-images/07-1.png",
-                "assets/sample-images/07-2.png",
-                "assets/sample-images/07-3.png",
-                "assets/sample-images/07-4.png",
-            ],
-        }
-    ],
-    [
-        {
-            "text": "Write a short story about what might have happened in this house.",
-            "files": ["assets/sample-images/08.png"],
-        }
-    ],
-    [
-        {
-            "text": "Create a short story based on the sequence of images.",
-            "files": [
-                "examples/09-1.png",
-                "examples/09-2.png",
-                "examples/09-3.png",
-                "examples/09-4.png",
-                "examples/09-5.png",
-            ],
-        }
-    ],
-    [
-        {
-            "text": "Describe the creatures that would live in this world.",
-            "files": ["assets/sample-images/10.png"],
-        }
-    ],
-    [
-        {
-            "text": "Read text in the image.",
-            "files": ["assets/additional-examples/1.png"],
-        }
-    ],
-    [
-        {
-            "text": "When is this ticket dated and how much did it cost?",
-            "files": ["assets/additional-examples/2.png"],
-        }
-    ],
-    [
-        {
-            "text": "Read the text in the image into markdown.",
-            "files": ["assets/additional-examples/3.png"],
-        }
-    ],
-    [
-        {
-            "text": "Evaluate this integral.",
-            "files": ["assets/additional-examples/4.png"],
-        }
-    ],
-    [
-        {
-            "text": "caption this image",
-            "files": ["assets/sample-images/01.png"],
-        }
-    ],
-    [
-        {
-            "text": "What's the sign says?",
-            "files": ["assets/sample-images/02.png"],
-        }
-    ],
-    [
-        {
-            "text": "Compare and contrast the two images.",
-            "files": ["assets/sample-images/03.png"],
-        }
-    ],
-    [
-        {
-            "text": "List all the objects in the image and their colors.",
-            "files": ["assets/sample-images/04.png"],
-        }
-    ],
-    [
-        {
-            "text": "Describe the atmosphere of the scene.",
-            "files": ["assets/sample-images/05.png"],
-        }
-    ],
-]
-
-DESCRIPTION = """\
-<img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />
-
-This is a demo of Gemma 3 12B it, a vision language model with outstanding performance on a wide range of tasks.
-You can upload images, interleaved images and videos. Note that video input only supports single-turn conversation and mp4 input. For videos, up to 5 frames will be extracted and processed.
-"""

 demo = gr.ChatInterface(
-    fn=run,
-    type="messages",
-    chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
-    textbox=gr.MultimodalTextbox(file_types=["image", ".mp4"], file_count="multiple", autofocus=True),
-    multimodal=True,
     additional_inputs=[
-        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
-        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
     ],
-    stop_btn=False,
-    title="Gemma 3 12B IT",
-    description=DESCRIPTION,
-    examples=examples,
-    run_examples_on_click=False,
     cache_examples=False,
-    css_paths="style.css",
-    delete_cache=(1800, 1800),
 )

 if __name__ == "__main__":
-    demo.launch(debug=True)
+"""
+app.py

+This demo builds a Multimodal OCR Granite Vision interface using:
+- @rag: retrieval-augmented generation for PDF and image documents (via LightRAG)
+- @granite: image understanding with Granite Vision
+- @video-infer: video understanding by downsampling frames and processing each with Granite Vision
+
+Make sure the required Granite models and dependencies (Gradio, Transformers, etc.) are installed.
+"""

+import os
+import random
+import uuid
+import time
 import cv2
+import numpy as np
 import torch
 from PIL import Image
+import gradio as gr

+from transformers import AutoProcessor, AutoModelForVision2Seq, AutoTokenizer, AutoModelForCausalLM
+from transformers.image_utils import load_image

+# Import the LightRAG class (which internally uses Granite embedding and generation models)
+from sandbox.light_rag.light_rag import LightRAG

+# ------------------------------
+# Utility and device setup
+# ------------------------------
+def get_device():
+    if torch.backends.mps.is_available():
+        return "mps"  # macOS GPU
+    elif torch.cuda.is_available():
+        return "cuda"
+    else:
+        return "cpu"
+
+device = get_device()
+
+# ------------------------------
+# Generation parameter constants
+# ------------------------------
+MAX_NEW_TOKENS = 1024
+TEMPERATURE = 0.7
+TOP_P = 0.85
+TOP_K = 50
+REPETITION_PENALTY = 1.05
+
+# ------------------------------
+# Load Granite Vision model for image processing (@granite and video)
+# ------------------------------
+VISION_MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
+vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID)
+vision_model = AutoModelForVision2Seq.from_pretrained(VISION_MODEL_ID, device_map="auto").to(device)
+
+# ------------------------------
+# Initialize the LightRAG pipeline for text-only or document (PDF/image) RAG (@rag)
+# ------------------------------
+rag_config = {
+    "embedding_model_id": "ibm-granite/granite-embedding-125m-english",
+    "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
+    "milvus_collection_name": "granite_vision_text_milvus",
+    "milvus_db_path": "milvus.db",  # adjust this path as needed
+}
+light_rag = LightRAG(rag_config)
+
+# ------------------------------
+# Video downsampling helper
+# ------------------------------
+def downsample_video(video_path):
+    """
+    Downsamples the video to 10 evenly spaced frames.
+    Returns a list of tuples: (PIL image, timestamp in seconds)
+    """
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
     frames = []
+    # Sample 10 evenly spaced frame indices
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    for i in frame_indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, frame = vidcap.read()
         if success:
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(frame)
             timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
     vidcap.release()
     return frames

+# ------------------------------
+# Command processing functions
+# ------------------------------
+def process_rag(query, file_path=None):
+    """
+    Process @rag command using the LightRAG pipeline.
+    Optionally, if a file is provided (e.g. PDF or image), one might extract text from it.
+    Here we simply use the query for retrieval-augmented generation.
+    """
+    context = light_rag.search(query, top_n=5)
+    answer, prompt = light_rag.generate(query, context)
+    return answer
+
+def process_granite(query, image: Image.Image):
+    """
+    Process @granite command:
+    Build a single-turn prompt from the image and the query, then run the Granite Vision model.
+    """
+    # Build a single user turn that pairs the image with the text query,
+    # so the vision model actually receives the image alongside the prompt.
+    conversation = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": query},
+            ],
+        }
+    ]
+    inputs = vision_processor.apply_chat_template(
+        conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+    ).to(device)
+    generate_kwargs = {
+        "max_new_tokens": MAX_NEW_TOKENS,
+        "do_sample": True,
+        "top_p": TOP_P,
+        "top_k": TOP_K,
+        "temperature": TEMPERATURE,
+        "repetition_penalty": REPETITION_PENALTY,
+    }
+    output = vision_model.generate(**inputs, **generate_kwargs)
+    result = vision_processor.decode(output[0], skip_special_tokens=True)
+    return result.strip()
+
+def process_video(query, video_path):
+    """
+    Process @video-infer command:
+    Downsample the video, process each frame with the Granite Vision model, and combine the results.
+    """
     frames = downsample_video(video_path)
+    descriptions = []
+    for image, timestamp in frames:
+        desc = process_granite(query, image)
+        descriptions.append(f"At {timestamp}s: {desc}")
+    return "\n".join(descriptions)
+
+# ------------------------------
+# Main function to handle input and dispatch based on command
+# ------------------------------
+def generate_response(input_dict, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
+    """
+    Based on the query prefix, this function calls:
+      - process_rag for @rag
+      - process_granite for @granite
+      - process_video for @video-infer
+    If no special command is provided, it defaults to text-only generation via LightRAG.
+    """
+    text = input_dict["text"]
+    files = input_dict.get("files", [])
+    lower_text = text.strip().lower()
+
+    if lower_text.startswith("@rag"):
+        query = text[len("@rag"):].strip()
+        file_path = files[0] if files else None  # Optionally process the provided file
+        answer = process_rag(query, file_path)
+        return answer
+
+    elif lower_text.startswith("@granite"):
+        query = text[len("@granite"):].strip()
+        if files:
+            # Assume first file is an image
+            image = load_image(files[0])
+            result = process_granite(query, image)
+            return result
         else:
+            return "No image file provided for @granite command."
+
+    elif lower_text.startswith("@video-infer"):
+        query = text[len("@video-infer"):].strip()
+        if files:
+            video_path = files[0]  # Assume first file is a video
+            result = process_video(query, video_path)
+            return result
+        else:
+            return "No video file provided for @video-infer command."
+
+    else:
+        # Default: text-only generation using LightRAG
+        answer, prompt = light_rag.generate(text, context=[])
+        return answer

+# ------------------------------
+# Build the Gradio interface using a multimodal textbox
+# ------------------------------
 demo = gr.ChatInterface(
+    fn=generate_response,
+    multimodal=True,  # deliver textbox input to generate_response as a {"text", "files"} dict
     additional_inputs=[
+        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=MAX_NEW_TOKENS),
+        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=TEMPERATURE),
+        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=TOP_P),
+        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=TOP_K),
+        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=REPETITION_PENALTY),
+    ],
+    textbox=gr.MultimodalTextbox(
+        label="Query Input",
+        file_types=["image", "pdf", "video"],
+        file_count="multiple",
+        placeholder="Enter your query starting with @rag, @granite, or @video-infer",
+    ),
+    examples=[
+        [{"text": "@rag What was the revenue growth in 2020?"}],
+        [{"text": "@granite Describe the content of this image", "files": ["example_image.png"]}],
+        [{"text": "@video-infer Summarize the event shown in the video", "files": ["example_video.mp4"]}],
     ],
     cache_examples=False,
+    type="messages",
+    description=(
+        "### Multimodal OCR Granite Vision\n"
+        "Use **@rag** for PDF/image RAG, **@granite** for image questions, and **@video-infer** for video understanding."
+    ),
+    fill_height=True,
+    stop_btn="Stop Generation",
+    theme="default",
 )

 if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
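
The new downsample_video picks 10 evenly spaced frame indices with np.linspace and pairs each frame with a timestamp derived from the clip's FPS. A minimal, self-contained sketch of what that sampling yields, assuming a hypothetical 300-frame clip at 30 fps (values below are illustrative only, not from the commit):

# Sketch of the frame sampling used by downsample_video for a hypothetical clip.
import numpy as np

total_frames, fps = 300, 30.0  # hypothetical clip properties
frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
timestamps = [round(int(i) / fps, 2) for i in frame_indices]
print(frame_indices.tolist())  # [0, 33, 66, 99, 132, 166, 199, 232, 265, 299]
print(timestamps)              # [0.0, 1.1, 2.2, 3.3, 4.4, 5.53, 6.63, 7.73, 8.83, 9.97]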
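
process_granite pairs the uploaded image with the text query in a single chat turn. If you would rather reference images by local file path (the approach the removed Gemma code took with temporary PNG files, and the "url" form shown in the granite-vision usage examples), here is a sketch of a hypothetical helper, build_vision_turn, which is not part of app.py:

# Hypothetical helper (not in the commit): write a PIL frame to a temporary PNG and
# reference it by local path in the chat turn, mirroring the removed Gemma code's approach.
import tempfile
from PIL import Image

def build_vision_turn(query: str, image: Image.Image) -> list[dict]:
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        image.save(tmp.name)
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": tmp.name},  # local file path, loaded by the processor
                {"type": "text", "text": query},
            ],
        }
    ]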
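
The @-prefix dispatch in generate_response can be exercised without loading any models. A standalone sketch of the same routing logic, using a hypothetical route_command function that does not exist in app.py:

# Standalone sketch of the prefix routing generate_response performs (illustrative only).
def route_command(text: str) -> tuple[str, str]:
    stripped = text.strip()
    lowered = stripped.lower()
    for prefix in ("@rag", "@granite", "@video-infer"):
        if lowered.startswith(prefix):
            return prefix, stripped[len(prefix):].strip()
    return "default", stripped  # falls back to text-only LightRAG generation

assert route_command("@granite Describe the content of this image") == ("@granite", "Describe the content of this image")
assert route_command("@video-infer Summarize the event shown in the video") == ("@video-infer", "Summarize the event shown in the video")
assert route_command("What was the revenue growth in 2020?") == ("default", "What was the revenue growth in 2020?")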