""" | |
app.py | |
This demo builds a Multimodal OCR Granite Vision interface using: | |
- @rag: retrieval‐augmented generation for PDF and image documents (via LightRAG) | |
- @granite: image understanding with Granite Vision | |
- @video-infer: video understanding by downsampling frames and processing each with Granite Vision | |
Make sure the required Granite models and dependencies (Gradio, Transformers, etc.) are installed. | |
""" | |
import cv2
import numpy as np
import torch
from PIL import Image
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image

# Import the LightRAG class (which internally uses Granite embedding and generation models)
from sandbox.light_rag.light_rag import LightRAG
# ------------------------------
# Utility and device setup
# ------------------------------
def get_device():
    if torch.backends.mps.is_available():
        return "mps"  # macOS GPU
    elif torch.cuda.is_available():
        return "cuda"
    else:
        return "cpu"


device = get_device()
# ------------------------------
# Generation parameter constants
# ------------------------------
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05
# ------------------------------
# Load Granite Vision model for image processing (@granite and @video-infer)
# ------------------------------
VISION_MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID)
# Place the model with a single .to() call. Combining device_map="auto" with an explicit
# .to(device) conflicts once accelerate has dispatched the weights, so only one is used here.
vision_model = AutoModelForVision2Seq.from_pretrained(VISION_MODEL_ID).to(device)
# ------------------------------
# Initialize the LightRAG pipeline for text-only or document (PDF/image) RAG (@rag)
# ------------------------------
rag_config = {
    "embedding_model_id": "ibm-granite/granite-embedding-125m-english",
    "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
    "milvus_collection_name": "granite_vision_text_milvus",
    "milvus_db_path": "milvus.db",  # adjust this path as needed
}
light_rag = LightRAG(rag_config)
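
# NOTE (assumption): the LightRAG wrapper from sandbox.light_rag is expected to expose
#   light_rag.search(query, top_n=...)  -> list of retrieved context passages
#   light_rag.generate(query, context)  -> (answer, prompt) tuple
# as used below. If your local LightRAG version differs, adapt process_rag() and the
# text-only fallback in generate_response() accordingly.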
# ------------------------------
# Video downsampling helper
# ------------------------------
def downsample_video(video_path):
    """
    Downsamples the video to 10 evenly spaced frames.
    Returns a list of tuples: (PIL image, timestamp in seconds).
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    # Sample 10 evenly spaced frame indices
    frame_indices = np.linspace(0, max(total_frames - 1, 0), 10, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        success, frame = vidcap.read()
        if success:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame)
            # Guard against videos whose FPS metadata is missing or zero
            timestamp = round(i / fps, 2) if fps else 0.0
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
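
# Illustrative only: for a ~12 s clip at 24 fps, downsample_video("clip.mp4") returns roughly
# [(<PIL.Image>, 0.0), (<PIL.Image>, 1.33), ..., (<PIL.Image>, 11.96)] -- ten (frame, timestamp)
# pairs that process_video() below feeds to the vision model one by one.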
# ------------------------------
# Command processing functions
# ------------------------------
def process_rag(query, file_path=None):
    """
    Process the @rag command using the LightRAG pipeline.
    The optional file_path (e.g. a PDF or image) is currently not ingested here;
    only the query is used for retrieval-augmented generation.
    """
    context = light_rag.search(query, top_n=5)
    answer, prompt = light_rag.generate(query, context)
    return answer
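
# Possible extension (not implemented): when file_path points to a PDF, extract its text
# (e.g. with pypdf) and index it into the Milvus collection before calling search(), so the
# uploaded document itself becomes part of the retrieval context rather than being ignored.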
def process_granite(query, image: Image.Image):
    """
    Process the @granite command:
    Build a chat turn containing the image and the query, then run the Granite Vision model.
    """
    # The image must be part of the conversation, otherwise the model never sees it.
    # Passing a PIL image under the "image" key requires a recent transformers release.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query},
            ],
        }
    ]
    inputs = vision_processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
    ).to(device)
    generate_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": True,
        "top_p": TOP_P,
        "top_k": TOP_K,
        "temperature": TEMPERATURE,
        "repetition_penalty": REPETITION_PENALTY,
    }
    output = vision_model.generate(**inputs, **generate_kwargs)
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    result = vision_processor.decode(new_tokens, skip_special_tokens=True)
    return result.strip()
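
# Example usage (hypothetical file name):
#   process_granite("What does the chart show?", load_image("report_page.png"))
# returns a free-text answer grounded in the supplied image.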
def process_video(query, video_path):
    """
    Process the @video-infer command:
    Downsample the video, process each frame with the Granite Vision model, and combine the results.
    """
    frames = downsample_video(video_path)
    descriptions = []
    for image, timestamp in frames:
        desc = process_granite(query, image)
        descriptions.append(f"At {timestamp}s: {desc}")
    return "\n".join(descriptions)
# ------------------------------
# Main function to handle input and dispatch based on command
# ------------------------------
def generate_response(input_dict, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    """
    Based on the query prefix, this function calls:
      - process_rag for @rag
      - process_granite for @granite
      - process_video for @video-infer
    If no special command is provided, it defaults to text-only generation via LightRAG.

    The slider values are received from the UI for interface compatibility; the helper
    functions above currently use the module-level generation constants.
    """
    text = input_dict["text"]
    files = input_dict.get("files", [])
    lower_text = text.strip().lower()
    if lower_text.startswith("@rag"):
        query = text[len("@rag"):].strip()
        file_path = files[0] if files else None  # Optionally process the provided file
        answer = process_rag(query, file_path)
        return answer
    elif lower_text.startswith("@granite"):
        query = text[len("@granite"):].strip()
        if files:
            # Assume the first file is an image
            image = load_image(files[0])
            result = process_granite(query, image)
            return result
        else:
            return "No image file provided for the @granite command."
    elif lower_text.startswith("@video-infer"):
        query = text[len("@video-infer"):].strip()
        if files:
            video_path = files[0]  # Assume the first file is a video
            result = process_video(query, video_path)
            return result
        else:
            return "No video file provided for the @video-infer command."
    else:
        # Default: text-only generation using LightRAG
        answer, prompt = light_rag.generate(text, context=[])
        return answer
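
# With a multimodal ChatInterface, Gradio delivers each user turn to generate_response above as a
# dict shaped like {"text": "@granite Describe this", "files": ["/tmp/gradio/.../photo.png"]},
# where "files" holds local paths to the uploaded attachments (an empty list if none were attached).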
# ------------------------------
# Build the Gradio interface using a multimodal textbox
# ------------------------------
demo = gr.ChatInterface(
    fn=generate_response,
    multimodal=True,  # required so the fn receives {"text": ..., "files": [...]} dicts
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=TEMPERATURE),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=TOP_P),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=TOP_K),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=REPETITION_PENALTY),
    ],
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", ".pdf", "video"],  # file_types accepts type groups or extensions
        file_count="multiple",
        placeholder="Enter your query starting with @rag, @granite, or @video-infer",
    ),
    examples=[
        [{"text": "@rag What was the revenue growth in 2020?"}],
        [{"text": "@granite Describe the content of this image", "files": ["example_image.png"]}],
        [{"text": "@video-infer Summarize the event shown in the video", "files": ["example_video.mp4"]}],
    ],
    cache_examples=False,
    type="messages",
    description=(
        "### Multimodal OCR Granite Vision\n"
        "Use **@rag** for PDF/image RAG, **@granite** for image questions, and **@video-infer** for video understanding."
    ),
    fill_height=True,
    stop_btn="Stop Generation",
    theme="default",
)
if __name__ == "__main__":
    demo.queue(max_size=20).launch()