prithivMLmods commited on
Commit
4425d90
·
verified ·
1 Parent(s): 896773c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -324
app.py CHANGED
@@ -1,332 +1,210 @@
1
- """
2
- app.py
3
-
4
- A unified Gradio chat application for Multimodal OCR Granite Vision.
5
- Commands (enter these as a prefix in the text input):
6
- - @rag: For retrieval‐augmented generation (e.g. PDFs or text-based queries).
7
- - @granite: For image understanding.
8
- - @video-infer: For video understanding (video is downsampled into frames).
9
-
10
- The app uses gr.MultimodalTextbox to support text input together with file uploads.
11
- """
12
-
13
- import os
14
- import time
15
- import uuid
16
- import random
17
  import spaces
18
- import logging
19
- from threading import Thread
20
- from pathlib import Path
21
- from datetime import datetime, timezone
22
-
23
  import torch
24
- import numpy as np
25
- import cv2
26
  from PIL import Image
27
- import gradio as gr
28
-
29
- from transformers import (
30
- AutoModelForCausalLM,
31
- AutoTokenizer,
32
- AutoProcessor,
33
- AutoModelForVision2Seq,
34
- )
35
-
36
- # ---------------------------
37
- # Utility functions and setup
38
- # ---------------------------
39
-
40
- def get_device():
41
- if torch.backends.mps.is_available():
42
- return "mps" # mac GPU
43
- elif torch.cuda.is_available():
44
- return "cuda"
45
- else:
46
- return "cpu"
47
-
48
- device = get_device()
49
-
50
- logging.basicConfig(
51
- level=logging.INFO,
52
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
53
- )
54
-
55
- def downsample_video(video_path):
56
- """
57
- Downsamples the video into 10 evenly spaced frames.
58
- Returns a list of (PIL Image, timestamp in seconds) tuples.
59
- """
60
- vidcap = cv2.VideoCapture(video_path)
61
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
62
- fps = vidcap.get(cv2.CAP_PROP_FPS)
63
- frames = []
64
- frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
65
- for i in frame_indices:
66
- vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
67
- success, image = vidcap.read()
68
- if success:
69
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
70
- pil_image = Image.fromarray(image)
71
- timestamp = round(i / fps, 2)
72
- frames.append((pil_image, timestamp))
73
- vidcap.release()
74
- return frames
75
-
76
- # ---------------------------
77
- # HF Embedding and LLM classes
78
- # ---------------------------
79
-
80
- class HFEmbedding:
81
- def __init__(self, model_id: str):
82
- self.model_name = model_id
83
- logging.info(f"Loading embeddings model from: {self.model_name}")
84
- # Using langchain_huggingface for embeddings
85
- from langchain_huggingface import HuggingFaceEmbeddings # ensure installed
86
- # For simplicity, force CPU (adjust if needed)
87
- self.embeddings_service = HuggingFaceEmbeddings(
88
- model_name=self.model_name,
89
- model_kwargs={"device": "cpu"},
90
- )
91
-
92
- def embed_documents(self, texts: list[str]) -> list[list[float]]:
93
- return self.embeddings_service.embed_documents(texts)
94
-
95
- def embed_query(self, text: str) -> list[float]:
96
- return self.embed_documents([text])[0]
97
-
98
- class HFLLM:
99
- def __init__(self, model_name: str):
100
- self.device = device
101
- self.model_name = model_name
102
- logging.info("Loading HF language model...")
103
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
104
- self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
105
-
106
- def generate(self, prompt: str) -> list:
107
- # Tokenize prompt and generate text
108
- model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
109
- generated_ids = self.model.generate(**model_inputs, max_new_tokens=1024)
110
- generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
111
- # Extract answer assuming a marker in the generated text
112
- response = [{"answer": generated_texts[0].split("<|end_of_role|>")[-1].split("<|end_of_text|>")[0]}]
113
- return response
114
-
115
- # ---------------------------
116
- # LightRAG: Retrieval-Augmented Generation (Dummy)
117
- # ---------------------------
118
-
119
- class LightRAG:
120
- def __init__(self, config: dict):
121
- self.config = config
122
- # Load generation and embedding models immediately (or lazy load as needed)
123
- self.gen_model = HFLLM(config['generation_model_id'])
124
- self._embedding_model = HFEmbedding(config['embedding_model_id'])
125
-
126
- def search(self, query: str, top_n: int = 5) -> list:
127
- # Dummy retrieval: In practice, integrate with a vector store
128
- from langchain_core.documents import Document # ensure langchain_core is installed
129
- dummy_doc = Document(
130
- page_content="Dummy context for query: " + query,
131
- metadata={"type": "text"}
132
- )
133
- return [dummy_doc]
134
-
135
- def generate(self, query, context=None):
136
- if context is None:
137
- context = []
138
- # Build prompt by concatenating retrieved context with the query.
139
- prompt = ""
140
- for doc in context:
141
- prompt += doc.page_content + "\n"
142
- prompt += "\nQuestion: " + query + "\nAnswer:"
143
- results = self.gen_model.generate(prompt)
144
- answer = results[0]["answer"]
145
- return answer, prompt
146
 
147
- # Global configuration for LightRAG
148
- rag_config = {
149
- "embedding_model_id": "ibm-granite/granite-embedding-125m-english",
150
- "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  }
152
- light_rag = LightRAG(rag_config)
153
-
154
- # ---------------------------
155
- # Granite Vision functions (for image and video)
156
- # ---------------------------
157
-
158
- # Set the Granite Vision model ID (adjust version as needed)
159
- GRANITE_MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
160
- granite_processor = None
161
- granite_model = None
162
-
163
- def load_granite_model():
164
- """Lazy load the Granite vision processor and model."""
165
- global granite_processor, granite_model
166
- if granite_processor is None or granite_model is None:
167
- granite_processor = AutoProcessor.from_pretrained(GRANITE_MODEL_ID)
168
- # Remove the .to(device) call to avoid moving a model already offloaded via accelerate.
169
- granite_model = AutoModelForVision2Seq.from_pretrained(GRANITE_MODEL_ID, device_map="auto")
170
- return granite_processor, granite_model
171
-
172
- def create_single_turn(image, text):
173
- """
174
- Creates a single-turn conversation message.
175
- If an image is provided, it is added along with the text.
176
- """
177
- if image is None:
178
- return {"role": "user", "content": [{"type": "text", "text": text}]}
179
- else:
180
- return {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text}]}
181
-
182
- def generate_granite(image, prompt_text, max_new_tokens=1024, temperature=0.7, top_p=0.85, top_k=50, repetition_penalty=1.05):
183
- """
184
- Generates a response from the Granite Vision model given an image and prompt.
185
- """
186
- processor, model = load_granite_model()
187
- conversation = [create_single_turn(image, prompt_text)]
188
- inputs = processor.apply_chat_template(
189
- conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
190
- ).to(device)
191
- output = model.generate(
192
- **inputs,
193
- max_new_tokens=max_new_tokens,
194
- do_sample=True,
195
- top_p=top_p,
196
- top_k=top_k,
197
- temperature=temperature,
198
- repetition_penalty=repetition_penalty,
199
- )
200
- decoded = processor.decode(output[0], skip_special_tokens=True)
201
- parts = decoded.strip().split("<|assistant|>")
202
- return parts[-1].strip()
203
-
204
- def generate_video_infer(video_path, prompt_text, max_new_tokens=1024, temperature=0.7, top_p=0.85, top_k=50, repetition_penalty=1.05):
205
- """
206
- Processes a video file by downsampling frames and sending them along with a prompt
207
- to the Granite Vision model.
208
- """
209
- frames = downsample_video(video_path)
210
- conversation_content = []
211
- for img, ts in frames:
212
- conversation_content.append({"type": "text", "text": f"Frame at {ts} sec:"})
213
- conversation_content.append({"type": "image", "image": img})
214
- conversation_content.append({"type": "text", "text": prompt_text})
215
- conversation = [{"role": "user", "content": conversation_content}]
216
- processor, model = load_granite_model()
217
- inputs = processor.apply_chat_template(
218
- conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
219
- ).to(device)
220
- output = model.generate(
221
- **inputs,
222
- max_new_tokens=max_new_tokens,
223
- do_sample=True,
224
- top_p=top_p,
225
- top_k=top_k,
226
- temperature=temperature,
227
- repetition_penalty=repetition_penalty,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  )
229
- decoded = processor.decode(output[0], skip_special_tokens=True)
230
- parts = decoded.strip().split("<|assistant|>")
231
- return parts[-1].strip()
232
-
233
- # ---------------------------
234
- # Unified generation function for ChatInterface
235
- # ---------------------------
236
- @spaces.GPU
237
- def generate(input_dict: dict, chat_history: list[dict],
238
- max_new_tokens: int, temperature: float,
239
- top_p: float, top_k: int, repetition_penalty: float):
240
- """
241
- Chat function that inspects the input text for special commands and routes:
242
- - @rag: Uses the RAG pipeline.
243
- - @granite: Uses Granite Vision for image understanding.
244
- - @video-infer: Uses Granite Vision for video processing.
245
- """
246
- text = input_dict["text"]
247
- files = input_dict.get("files", [])
248
- lower_text = text.strip().lower()
249
-
250
- # Optionally yield a progress message
251
- yield "Processing your request..."
252
- time.sleep(1) # simulate processing delay
253
-
254
- if lower_text.startswith("@rag"):
255
- query = text[len("@rag"):].strip()
256
- logging.info(f"@rag command: {query}")
257
- context = light_rag.search(query)
258
- answer, _ = light_rag.generate(query, context)
259
- yield answer
260
-
261
- elif lower_text.startswith("@granite"):
262
- prompt_text = text[len("@granite"):].strip()
263
- logging.info(f"@granite command: {prompt_text}")
264
- if files:
265
- # Expecting an image file (as a PIL image)
266
- image = files[0]
267
- answer = generate_granite(image, prompt_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
268
- yield answer
269
- else:
270
- yield "No image provided for @granite command."
271
-
272
- elif lower_text.startswith("@video-infer"):
273
- prompt_text = text[len("@video-infer"):].strip()
274
- logging.info(f"@video-infer command: {prompt_text}")
275
- if files:
276
- # Expecting a video file (the file path)
277
- video_path = files[0]
278
- answer = generate_video_infer(video_path, prompt_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
279
- yield answer
280
- else:
281
- yield "No video provided for @video-infer command."
282
-
283
- else:
284
- # Default behavior: use RAG pipeline for text query.
285
- query = text.strip()
286
- logging.info(f"Default text query: {query}")
287
- context = light_rag.search(query)
288
- answer, _ = light_rag.generate(query, context)
289
- yield answer
290
-
291
- # ---------------------------
292
- # Gradio ChatInterface using MultimodalTextbox
293
- # ---------------------------
294
-
295
- demo = gr.ChatInterface(
296
- fn=generate,
297
- additional_inputs=[
298
- gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
299
- gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
300
- gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.1, value=0.85),
301
- gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
302
- gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.05),
303
- ],
304
- examples=[
305
- # Examples show how to use the command prefixes.
306
- [{"text": "@rag What models are available in Watsonx?"}],
307
- [{"text": "@granite Describe the image", "files": [str(Path("examples") / "sample_image.png")]}],
308
- [{"text": "@video-infer Summarize the event in the video", "files": [str(Path("examples") / "sample_video.mp4")]}],
309
- ],
310
- cache_examples=False,
311
- type="messages",
312
- description=(
313
- "# **Multimodal OCR Granite Vision**\n\n"
314
- "Enter a command in the text input (with optional file uploads) using one of the following prefixes:\n\n"
315
- "- **@rag**: For retrieval-augmented generation (e.g. PDFs, documents).\n"
316
- "- **@granite**: For image understanding using Granite Vision.\n"
317
- "- **@video-infer**: For video understanding (video is downsampled into frames).\n\n"
318
- "For example:\n```\n@rag What is the revenue trend?\n```\n```\n@granite Describe this image\n```\n```\n@video-infer Summarize the event in this video\n```"
319
- ),
320
- fill_height=True,
321
- textbox=gr.MultimodalTextbox(
322
- label="Query Input",
323
- file_types=["image", "video", "pdf"],
324
- file_count="multiple",
325
- placeholder="@rag, @granite, or @video-infer followed by your prompt"
326
- ),
327
- stop_btn="Stop Generation",
328
- multimodal=True,
329
- )
330
 
331
  if __name__ == "__main__":
332
- demo.queue(max_size=20).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import spaces
2
+ import gradio as gr
 
 
 
 
3
  import torch
 
 
4
  from PIL import Image
5
+ from diffusers import DiffusionPipeline
6
+ import random
7
+ import uuid
8
+ from typing import Tuple
9
+ import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ def save_image(img):
12
+ unique_name = str(uuid.uuid4()) + ".png"
13
+ img.save(unique_name)
14
+ return unique_name
15
+
16
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
17
+ if randomize_seed:
18
+ seed = random.randint(0, MAX_SEED)
19
+ return seed
20
+
21
+ MAX_SEED = np.iinfo(np.int32).max
22
+
23
+ if not torch.cuda.is_available():
24
+ DESCRIPTIONz += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
25
+
26
+ base_model = "black-forest-labs/FLUX.1-dev"
27
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
28
+
29
+ lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
30
+ trigger_word = "Super Realism" # Leave trigger_word blank if not used.
31
+
32
+ pipe.load_lora_weights(lora_repo)
33
+ pipe.to("cuda")
34
+
35
+ style_list = [
36
+ {
37
+ "name": "3840 x 2160",
38
+ "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
39
+ },
40
+ {
41
+ "name": "2560 x 1440",
42
+ "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
43
+ },
44
+ {
45
+ "name": "HD+",
46
+ "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
47
+ },
48
+ {
49
+ "name": "Style Zero",
50
+ "prompt": "{prompt}",
51
+ },
52
+ ]
53
+
54
+ styles = {k["name"]: k["prompt"] for k in style_list}
55
+
56
+ DEFAULT_STYLE_NAME = "3840 x 2160"
57
+ STYLE_NAMES = list(styles.keys())
58
+
59
+ def apply_style(style_name: str, positive: str) -> str:
60
+ return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
61
+
62
+ @spaces.GPU(duration=60, enable_queue=True)
63
+ def generate(
64
+ prompt: str,
65
+ seed: int = 0,
66
+ width: int = 1024,
67
+ height: int = 1024,
68
+ guidance_scale: float = 3,
69
+ randomize_seed: bool = False,
70
+ style_name: str = DEFAULT_STYLE_NAME,
71
+ progress=gr.Progress(track_tqdm=True),
72
+ ):
73
+ seed = int(randomize_seed_fn(seed, randomize_seed))
74
+
75
+ positive_prompt = apply_style(style_name, prompt)
76
+
77
+ if trigger_word:
78
+ positive_prompt = f"{trigger_word} {positive_prompt}"
79
+
80
+ images = pipe(
81
+ prompt=positive_prompt,
82
+ width=width,
83
+ height=height,
84
+ guidance_scale=guidance_scale,
85
+ num_inference_steps=28,
86
+ num_images_per_prompt=1,
87
+ output_type="pil",
88
+ ).images
89
+ image_paths = [save_image(img) for img in images]
90
+ print(image_paths)
91
+ return image_paths, seed
92
+
93
+ examples = [
94
+ "Super Realism, High-resolution photograph, woman, UHD, photorealistic, shot on a Sony A7III --chaos 20 --ar 1:2 --style raw --stylize 250",
95
+ "Woman in a red jacket, snowy, in the style of hyper-realistic portraiture, caninecore, mountainous vistas, timeless beauty, palewave, iconic, distinctive noses --ar 72:101 --stylize 750 --v 6",
96
+ "Super Realism, Headshot of handsome young man, wearing dark gray sweater with buttons and big shawl collar, brown hair and short beard, serious look on his face, black background, soft studio lighting, portrait photography --ar 85:128 --v 6.0 --style",
97
+ "Super-realism, Purple Dreamy, a medium-angle shot of a young woman with long brown hair, wearing a pair of eye-level glasses, stands in front of a backdrop of purple and white lights. The womans eyes are closed, her lips are slightly parted, as if she is looking up at the sky. Her hair is cascading over her shoulders, framing her face. She is wearing a sleeveless top, adorned with tiny white dots, and a gold chain necklace around her neck. Her left earrings are dangling from her ears, adding a pop of color to the scene."
98
+ ]
99
+
100
+ css = '''
101
+ .gradio-container{max-width: 888px !important}
102
+ h1{text-align:center}
103
+ footer {
104
+ visibility: hidden
105
  }
106
+ .submit-btn {
107
+ background-color: #e34949 !important;
108
+ color: white !important;
109
+ }
110
+ .submit-btn:hover {
111
+ background-color: #ff3b3b !important;
112
+ }
113
+ '''
114
+
115
+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
116
+ with gr.Row():
117
+ with gr.Column(scale=1):
118
+ prompt = gr.Text(
119
+ label="Prompt",
120
+ show_label=False,
121
+ max_lines=1,
122
+ placeholder="Enter your prompt",
123
+ container=False,
124
+ )
125
+ run_button = gr.Button("Generate as ( 768 x 1024 )🤗", scale=0, elem_classes="submit-btn")
126
+
127
+ with gr.Accordion("Advanced options", open=True, visible=True):
128
+ seed = gr.Slider(
129
+ label="Seed",
130
+ minimum=0,
131
+ maximum=MAX_SEED,
132
+ step=1,
133
+ value=0,
134
+ visible=True
135
+ )
136
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
137
+
138
+ with gr.Row(visible=True):
139
+ width = gr.Slider(
140
+ label="Width",
141
+ minimum=512,
142
+ maximum=2048,
143
+ step=64,
144
+ value=768,
145
+ )
146
+ height = gr.Slider(
147
+ label="Height",
148
+ minimum=512,
149
+ maximum=2048,
150
+ step=64,
151
+ value=1024,
152
+ )
153
+
154
+ with gr.Row():
155
+ guidance_scale = gr.Slider(
156
+ label="Guidance Scale",
157
+ minimum=0.1,
158
+ maximum=20.0,
159
+ step=0.1,
160
+ value=3.0,
161
+ )
162
+ num_inference_steps = gr.Slider(
163
+ label="Number of inference steps",
164
+ minimum=1,
165
+ maximum=40,
166
+ step=1,
167
+ value=28,
168
+ )
169
+
170
+ style_selection = gr.Radio(
171
+ show_label=True,
172
+ container=True,
173
+ interactive=True,
174
+ choices=STYLE_NAMES,
175
+ value=DEFAULT_STYLE_NAME,
176
+ label="Quality Style",
177
+ )
178
+
179
+ with gr.Column(scale=2):
180
+ result = gr.Gallery(label="Result", columns=1, show_label=False)
181
+
182
+ gr.Examples(
183
+ examples=examples,
184
+ inputs=prompt,
185
+ outputs=[result, seed],
186
+ fn=generate,
187
+ cache_examples=False,
188
+ )
189
+
190
+ gr.on(
191
+ triggers=[
192
+ prompt.submit,
193
+ run_button.click,
194
+ ],
195
+ fn=generate,
196
+ inputs=[
197
+ prompt,
198
+ seed,
199
+ width,
200
+ height,
201
+ guidance_scale,
202
+ randomize_seed,
203
+ style_selection,
204
+ ],
205
+ outputs=[result, seed],
206
+ api_name="run",
207
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  if __name__ == "__main__":
210
+ demo.queue(max_size=40).launch()