prithivMLmods committed on
Commit 218ebfb · verified · 1 Parent(s): aaec4ba

Update app.py

Files changed (1)
  1. app.py +250 -140
app.py CHANGED
@@ -1,36 +1,44 @@
 """
 app.py
 
-This demo builds a Multimodal OCR Granite Vision interface using:
-- @rag: retrieval-augmented generation for PDF and image documents (via LightRAG)
-- @granite: image understanding with Granite Vision
-- @video-infer: video understanding by downsampling frames and processing each with Granite Vision
 
-Make sure the required Granite models and dependencies (Gradio, Transformers, etc.) are installed.
 """
 
 import os
-import random
-import uuid
 import time
-import cv2
-import numpy as np
 import torch
 from PIL import Image
 import gradio as gr
 
-from transformers import AutoProcessor, AutoModelForVision2Seq, AutoTokenizer, AutoModelForCausalLM
-from transformers.image_utils import load_image
 
-# Import the LightRAG class (which internally uses Granite embedding and generation models)
-from sandbox.light_rag.light_rag import LightRAG
 
-# ------------------------------
-# Utility and device setup
-# ------------------------------
 def get_device():
     if torch.backends.mps.is_available():
-        return "mps"  # macOS GPU
     elif torch.cuda.is_available():
         return "cuda"
     else:
@@ -38,182 +46,284 @@ def get_device():
 
 device = get_device()
 
-# ------------------------------
-# Generation parameter constants
-# ------------------------------
-MAX_NEW_TOKENS = 1024
-TEMPERATURE = 0.7
-TOP_P = 0.85
-TOP_K = 50
-REPETITION_PENALTY = 1.05
-
-# ------------------------------
-# Load Granite Vision model for image processing (@granite and video)
-# ------------------------------
-VISION_MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
-vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID)
-vision_model = AutoModelForVision2Seq.from_pretrained(VISION_MODEL_ID, device_map="auto").to(device)
-
-# ------------------------------
-# Initialize the LightRAG pipeline for text-only or document (PDF/image) RAG (@rag)
-# ------------------------------
-rag_config = {
-    "embedding_model_id": "ibm-granite/granite-embedding-125m-english",
-    "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
-    "milvus_collection_name": "granite_vision_text_milvus",
-    "milvus_db_path": "milvus.db",  # adjust this path as needed
-}
-light_rag = LightRAG(rag_config)
 
-# ------------------------------
-# Video downsampling helper
-# ------------------------------
 def downsample_video(video_path):
     """
-    Downsamples the video to 10 evenly spaced frames.
-    Returns a list of tuples: (PIL image, timestamp in seconds)
     """
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     frames = []
-    # Sample 10 evenly spaced frame indices
     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
     for i in frame_indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-        success, frame = vidcap.read()
         if success:
-            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            pil_image = Image.fromarray(frame)
             timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
     vidcap.release()
     return frames
 
-# ------------------------------
-# Command processing functions
-# ------------------------------
-def process_rag(query, file_path=None):
     """
-    Process @rag command using the LightRAG pipeline.
-    Optionally, if a file is provided (e.g. PDF or image), one might extract text from it.
-    Here we simply use the query for retrieval-augmented generation.
     """
-    context = light_rag.search(query, top_n=5)
-    answer, prompt = light_rag.generate(query, context)
-    return answer
 
-def process_granite(query, image: Image.Image):
     """
-    Process @granite command:
-    Build a simple prompt from the image and the query then run the Granite Vision model.
     """
-    # Here we build a conversation with a single user turn.
-    conversation = [{"role": "user", "content": query}]
-    inputs = vision_processor.apply_chat_template(
         conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
     ).to(device)
-    generate_kwargs = {
-        "max_new_tokens": MAX_NEW_TOKENS,
-        "do_sample": True,
-        "top_p": TOP_P,
-        "top_k": TOP_K,
-        "temperature": TEMPERATURE,
-        "repetition_penalty": REPETITION_PENALTY,
-    }
-    output = vision_model.generate(**inputs, **generate_kwargs)
-    result = vision_processor.decode(output[0], skip_special_tokens=True)
-    return result.strip()
-
-def process_video(query, video_path):
     """
-    Process @video-infer command:
-    Downsample the video, process each frame with the Granite Vision model, and combine the results.
     """
     frames = downsample_video(video_path)
-    descriptions = []
-    for image, timestamp in frames:
-        desc = process_granite(query, image)
-        descriptions.append(f"At {timestamp}s: {desc}")
-    return "\n".join(descriptions)
-
-# ------------------------------
-# Main function to handle input and dispatch based on command
-# ------------------------------
-def generate_response(input_dict, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
     """
-    Based on the query prefix, this function calls:
-    - process_rag for @rag
-    - process_granite for @granite
-    - process_video for @video-infer
-    If no special command is provided, it defaults to text-only generation via LightRAG.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
     lower_text = text.strip().lower()
-
     if lower_text.startswith("@rag"):
         query = text[len("@rag"):].strip()
-        file_path = files[0] if files else None  # Optionally process the provided file
-        answer = process_rag(query, file_path)
-        return answer
-
     elif lower_text.startswith("@granite"):
-        query = text[len("@granite"):].strip()
         if files:
-            # Assume first file is an image
-            image = load_image(files[0])
-            result = process_granite(query, image)
-            return result
         else:
-            return "No image file provided for @granite command."
-
     elif lower_text.startswith("@video-infer"):
-        query = text[len("@video-infer"):].strip()
         if files:
-            video_path = files[0]  # Assume first file is a video
-            result = process_video(query, video_path)
-            return result
         else:
-            return "No video file provided for @video-infer command."
-
     else:
-        # Default: text-only generation using LightRAG
-        answer, prompt = light_rag.generate(text, context=[])
-        return answer
 
-# ------------------------------
-# Build the Gradio interface using a multimodal textbox
-# ------------------------------
 demo = gr.ChatInterface(
-    fn=generate_response,
     additional_inputs=[
-        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=MAX_NEW_TOKENS),
-        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=TEMPERATURE),
-        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=TOP_P),
-        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=TOP_K),
-        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=REPETITION_PENALTY),
     ],
-    textbox=gr.MultimodalTextbox(
-        label="Query Input",
-        file_types=["image", "pdf", "video"],
-        file_count="multiple",
-        placeholder="Enter your query starting with @rag, @granite, or @video-infer",
-    ),
     examples=[
-        [{"text": "@rag What was the revenue growth in 2020?"}],
-        [{"text": "@granite Describe the content of this image", "files": ["example_image.png"]}],
-        [{"text": "@video-infer Summarize the event shown in the video", "files": ["example_video.mp4"]}],
     ],
     cache_examples=False,
     type="messages",
     description=(
-        "### Multimodal OCR Granite Vision\n"
-        "Use **@rag** for PDF/image RAG, **@granite** for image questions, and **@video-infer** for video understanding."
     ),
     fill_height=True,
     stop_btn="Stop Generation",
-    theme="default",
 )
 
 if __name__ == "__main__":
 
 """
 app.py
 
+A unified Gradio chat application for Multimodal OCR Granite Vision.
+Commands (enter these as a prefix in the text input):
+- @rag: For retrieval-augmented generation (e.g. PDF or text-based queries).
+- @granite: For image understanding.
+- @video-infer: For video understanding (the video is downsampled into frames).
 
+The app uses gr.MultimodalTextbox to support text input together with file uploads.
 """
 
 import os
 import time
+import uuid
+import random
+import logging
+from threading import Thread
+from pathlib import Path
+from datetime import datetime, timezone
+
 import torch
+import numpy as np
+import cv2
 from PIL import Image
 import gradio as gr
 
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    AutoProcessor,
+    AutoModelForVision2Seq,
+)
 
+# ---------------------------
+# Utility functions and setup
+# ---------------------------
 
 def get_device():
     if torch.backends.mps.is_available():
+        return "mps"  # mac GPU
     elif torch.cuda.is_available():
         return "cuda"
     else:
         return "cpu"
 
 device = get_device()
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
 
 def downsample_video(video_path):
     """
+    Downsamples the video into 10 evenly spaced frames.
+    Returns a list of (PIL Image, timestamp in seconds) tuples.
     """
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     frames = []
     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
     for i in frame_indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
         if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
             timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
     vidcap.release()
     return frames
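# --- Editor's usage sketch (illustrative; "clip.mp4" is a hypothetical path) ---
# downsample_video returns up to 10 (PIL.Image, seconds) pairs, which the
# vision prompts below are built from:
#
#     frames = downsample_video("clip.mp4")
#     for img, ts in frames:
#         print(f"{ts}s -> {img.size}")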
 
+# ---------------------------
+# HF Embedding and LLM classes
+# ---------------------------
+
+class HFEmbedding:
+    def __init__(self, model_id: str):
+        self.model_name = model_id
+        logging.info(f"Loading embeddings model from: {self.model_name}")
+        # Using langchain_huggingface for embeddings
+        from langchain_huggingface import HuggingFaceEmbeddings  # ensure installed
+        # For simplicity, force CPU (adjust if needed)
+        self.embeddings_service = HuggingFaceEmbeddings(
+            model_name=self.model_name,
+            model_kwargs={"device": "cpu"},
+        )
+
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
+        return self.embeddings_service.embed_documents(texts)
+
+    def embed_query(self, text: str) -> list[float]:
+        return self.embed_documents([text])[0]
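# --- Editor's usage sketch (illustrative) ---
# embed_query returns one vector per input string; its length is fixed by the
# embedding model (e.g. 768 dimensions for a 125M-parameter encoder):
#
#     emb = HFEmbedding("ibm-granite/granite-embedding-125m-english")
#     vec = emb.embed_query("What does Granite Vision do?")
#     print(len(vec))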
+
+class HFLLM:
+    def __init__(self, model_name: str):
+        self.device = device
+        self.model_name = model_name
+        logging.info("Loading HF language model...")
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
+
+    def generate(self, prompt: str) -> list:
+        # Tokenize the prompt and generate text
+        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+        generated_ids = self.model.generate(**model_inputs, max_new_tokens=1024)
+        generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
+        # Extract the answer, assuming Granite role markers in the generated text
+        response = [{"answer": generated_texts[0].split("<|end_of_role|>")[-1].split("<|end_of_text|>")[0]}]
+        return response
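# --- Editor's note (hedged sketch) ---
# The split on "<|end_of_role|>" assumes the prompt follows Granite's chat
# format. One way to produce such a prompt is the tokenizer's own template
# (apply_chat_template is the standard transformers API for this; the special
# tokens come from the model itself):
#
#     llm = HFLLM("ibm-granite/granite-3.1-8b-instruct")  # heavy download
#     chat = [{"role": "user", "content": "Name one use of RAG."}]
#     prompt = llm.tokenizer.apply_chat_template(
#         chat, tokenize=False, add_generation_prompt=True
#     )
#     print(llm.generate(prompt)[0]["answer"])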
+
+# ---------------------------
+# LightRAG: Retrieval-Augmented Generation (Dummy)
+# ---------------------------
+
+class LightRAG:
+    def __init__(self, config: dict):
+        self.config = config
+        # Load generation and embedding models immediately (or lazy-load as needed)
+        self.gen_model = HFLLM(config['generation_model_id'])
+        self._embedding_model = HFEmbedding(config['embedding_model_id'])
+
+    def search(self, query: str, top_n: int = 5) -> list:
+        # Dummy retrieval: in practice, integrate with a vector store
+        from langchain_core.documents import Document  # ensure langchain_core is installed
+        dummy_doc = Document(
+            page_content="Dummy context for query: " + query,
+            metadata={"type": "text"},
+        )
+        return [dummy_doc]
+
+    def generate(self, query, context=None):
+        if context is None:
+            context = []
+        # Build the prompt by concatenating the retrieved context with the query.
+        prompt = ""
+        for doc in context:
+            prompt += doc.page_content + "\n"
+        prompt += "\nQuestion: " + query + "\nAnswer:"
+        results = self.gen_model.generate(prompt)
+        answer = results[0]["answer"]
+        return answer, prompt
+
+# Global configuration for LightRAG
+rag_config = {
+    "embedding_model_id": "ibm-granite/granite-embedding-125m-english",
+    "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
+}
+light_rag = LightRAG(rag_config)
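# --- Editor's usage sketch (illustrative) ---
# The end-to-end RAG path exercised by the @rag command; the dummy search just
# echoes the query, so a real deployment would swap in a vector-store lookup:
#
#     docs = light_rag.search("What is Granite Vision?")
#     answer, prompt = light_rag.generate("What is Granite Vision?", docs)
#     print(answer)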
+
+# ---------------------------
+# Granite Vision functions (for image and video)
+# ---------------------------
+
+# Set the Granite Vision model ID (adjust version as needed)
+GRANITE_MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
+granite_processor = None
+granite_model = None
+
+def load_granite_model():
+    """Lazily load the Granite Vision processor and model."""
+    global granite_processor, granite_model
+    if granite_processor is None or granite_model is None:
+        granite_processor = AutoProcessor.from_pretrained(GRANITE_MODEL_ID)
+        # device_map="auto" already places the weights; no extra .to(device) call
+        granite_model = AutoModelForVision2Seq.from_pretrained(GRANITE_MODEL_ID, device_map="auto")
+    return granite_processor, granite_model
+
+def create_single_turn(image, text):
     """
+    Creates a single-turn conversation message.
+    If an image is provided, it is added along with the text.
     """
+    if image is None:
+        return {"role": "user", "content": [{"type": "text", "text": text}]}
+    else:
+        return {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text}]}
 
+def generate_granite(image, prompt_text, max_new_tokens=1024, temperature=0.7, top_p=0.85, top_k=50, repetition_penalty=1.05):
     """
+    Generates a response from the Granite Vision model given an image and a prompt.
     """
+    processor, model = load_granite_model()
+    conversation = [create_single_turn(image, prompt_text)]
+    inputs = processor.apply_chat_template(
         conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
     ).to(device)
+    output = model.generate(
+        **inputs,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        repetition_penalty=repetition_penalty,
+    )
+    decoded = processor.decode(output[0], skip_special_tokens=True)
+    parts = decoded.strip().split("<|assistant|>")
+    return parts[-1].strip()
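# --- Editor's usage sketch ("page.png" is a hypothetical path) ---
# The same entry point serves OCR-style transcription and open-ended
# description; pass None as the image for a text-only turn:
#
#     img = Image.open("page.png").convert("RGB")
#     print(generate_granite(img, "Transcribe all the text in this image."))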
+
+def generate_video_infer(video_path, prompt_text, max_new_tokens=1024, temperature=0.7, top_p=0.85, top_k=50, repetition_penalty=1.05):
     """
+    Processes a video file by downsampling frames and sending them along with a prompt
+    to the Granite Vision model.
     """
     frames = downsample_video(video_path)
+    conversation_content = []
+    for img, ts in frames:
+        conversation_content.append({"type": "text", "text": f"Frame at {ts} sec:"})
+        conversation_content.append({"type": "image", "image": img})
+    conversation_content.append({"type": "text", "text": prompt_text})
+    conversation = [{"role": "user", "content": conversation_content}]
+    processor, model = load_granite_model()
+    inputs = processor.apply_chat_template(
+        conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+    ).to(device)
+    output = model.generate(
+        **inputs,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        repetition_penalty=repetition_penalty,
+    )
+    decoded = processor.decode(output[0], skip_special_tokens=True)
+    parts = decoded.strip().split("<|assistant|>")
+    return parts[-1].strip()
+ return parts[-1].strip()
230
+
231
+ # ---------------------------
232
+ # Unified generation function for ChatInterface
233
+ # ---------------------------
234
+
235
+ def generate(input_dict: dict, chat_history: list[dict],
236
+ max_new_tokens: int, temperature: float,
237
+ top_p: float, top_k: int, repetition_penalty: float):
238
  """
239
+ Chat function that inspects the input text for special commands and routes:
240
+ - @rag: Uses the RAG pipeline.
241
+ - @granite: Uses Granite Vision for image understanding.
242
+ - @video-infer: Uses Granite Vision for video processing.
 
243
  """
244
  text = input_dict["text"]
245
  files = input_dict.get("files", [])
246
  lower_text = text.strip().lower()
247
+
248
+ # Optionally yield a progress message
249
+ yield "Processing your request..."
250
+ time.sleep(1) # simulate processing delay
251
+
252
  if lower_text.startswith("@rag"):
253
  query = text[len("@rag"):].strip()
254
+ logging.info(f"@rag command: {query}")
255
+ context = light_rag.search(query)
256
+ answer, _ = light_rag.generate(query, context)
257
+ yield answer
258
+
259
  elif lower_text.startswith("@granite"):
260
+ prompt_text = text[len("@granite"):].strip()
261
+ logging.info(f"@granite command: {prompt_text}")
262
  if files:
263
+ # Expecting an image file (as a PIL image)
264
+ image = files[0]
265
+ answer = generate_granite(image, prompt_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
266
+ yield answer
267
  else:
268
+ yield "No image provided for @granite command."
269
+
270
  elif lower_text.startswith("@video-infer"):
271
+ prompt_text = text[len("@video-infer"):].strip()
272
+ logging.info(f"@video-infer command: {prompt_text}")
273
  if files:
274
+ # Expecting a video file (the file path)
275
+ video_path = files[0]
276
+ answer = generate_video_infer(video_path, prompt_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
277
+ yield answer
278
  else:
279
+ yield "No video provided for @video-infer command."
280
+
281
  else:
282
+ # Default behavior: use RAG pipeline for text query.
283
+ query = text.strip()
284
+ logging.info(f"Default text query: {query}")
285
+ context = light_rag.search(query)
286
+ answer, _ = light_rag.generate(query, context)
287
+ yield answer
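# --- Editor's note ---
# generate is a generator: gr.ChatInterface renders each yield as the current
# response, so the "Processing your request..." message is replaced by the
# final answer rather than appended to it.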
+
+# ---------------------------
+# Gradio ChatInterface using MultimodalTextbox
+# ---------------------------
 
 demo = gr.ChatInterface(
+    fn=generate,
     additional_inputs=[
+        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
+        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
+        gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.1, value=0.85),
+        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
+        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.05),
     ],
     examples=[
+        # Examples show how to use the command prefixes.
+        [{"text": "@rag What models are available in Watsonx?"}],
+        [{"text": "@granite Describe the image", "files": [str(Path("examples") / "sample_image.png")]}],
+        [{"text": "@video-infer Summarize the event in the video", "files": [str(Path("examples") / "sample_video.mp4")]}],
     ],
     cache_examples=False,
     type="messages",
     description=(
+        "# **Multimodal OCR Granite Vision**\n\n"
+        "Enter a command in the text input (with optional file uploads) using one of the following prefixes:\n\n"
+        "- **@rag**: For retrieval-augmented generation (e.g. PDFs, documents).\n"
+        "- **@granite**: For image understanding using Granite Vision.\n"
+        "- **@video-infer**: For video understanding (the video is downsampled into frames).\n\n"
+        "For example:\n```\n@rag What is the revenue trend?\n```\n```\n@granite Describe this image\n```\n```\n@video-infer Summarize the event in this video\n```"
     ),
     fill_height=True,
+    textbox=gr.MultimodalTextbox(
+        label="Query Input",
+        file_types=["image", "video", "pdf"],
+        file_count="multiple",
+        placeholder="@rag, @granite, or @video-infer followed by your prompt",
+    ),
     stop_btn="Stop Generation",
+    multimodal=True,
 )
 
 if __name__ == "__main__":