""" | |
app.py | |
A unified Gradio chat application for Multimodal OCR Granite Vision. | |
Commands (enter these as a prefix in the text input): | |
- @rag: For retrieval‐augmented generation (e.g. PDF or text-based queries). | |
- @granite: For image understanding. | |
- @video-infer: For video understanding (video is downsampled into frames). | |
The app uses gr.MultimodalTextbox to support text input together with file uploads. | |
""" | |
import os
import time
import uuid
import random
import logging
from threading import Thread
from pathlib import Path
from datetime import datetime, timezone

import torch
import spaces
import numpy as np
import cv2
from PIL import Image
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
    AutoModelForVision2Seq,
)
# ---------------------------
# Utility functions and setup
# ---------------------------
def get_device():
    if torch.backends.mps.is_available():
        return "mps"  # Apple Silicon GPU
    elif torch.cuda.is_available():
        return "cuda"
    else:
        return "cpu"


device = get_device()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
def downsample_video(video_path):
    """
    Downsamples the video into 10 evenly spaced frames.
    Returns a list of (PIL Image, timestamp in seconds) tuples.
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    # Guard against unreadable videos (zero frames or missing FPS metadata).
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
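

# Usage sketch for downsample_video (not executed; "clip.mp4" is a hypothetical local file):
#   frames = downsample_video("clip.mp4")
#   for img, ts in frames:
#       print(ts, img.size)  # PIL image plus its timestamp in seconds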
# ---------------------------
# HF Embedding and LLM classes
# ---------------------------
class HFEmbedding:
    def __init__(self, model_id: str):
        self.model_name = model_id
        logging.info(f"Loading embeddings model from: {self.model_name}")
        # Using langchain_huggingface for embeddings
        from langchain_huggingface import HuggingFaceEmbeddings  # ensure installed
        # For simplicity, force CPU (adjust if needed)
        self.embeddings_service = HuggingFaceEmbeddings(
            model_name=self.model_name,
            model_kwargs={"device": "cpu"},
        )

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return self.embeddings_service.embed_documents(texts)

    def embed_query(self, text: str) -> list[float]:
        return self.embed_documents([text])[0]
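

# Usage sketch for HFEmbedding (illustrative only, not executed):
#   embedder = HFEmbedding("ibm-granite/granite-embedding-125m-english")
#   doc_vectors = embedder.embed_documents(["hello world", "granite vision"])
#   query_vector = embedder.embed_query("hello")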
class HFLLM:
    def __init__(self, model_name: str):
        self.device = device
        self.model_name = model_name
        logging.info("Loading HF language model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)

    def generate(self, prompt: str) -> list:
        # Tokenize prompt and generate text
        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=1024)
        generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
        # Extract the answer between the Granite role and end-of-text markers
        response = [{"answer": generated_texts[0].split("<|end_of_role|>")[-1].split("<|end_of_text|>")[0]}]
        return response
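

# Usage sketch for HFLLM (illustrative only, not executed; model id mirrors rag_config below):
#   llm = HFLLM("ibm-granite/granite-3.1-8b-instruct")
#   print(llm.generate("Question: What is OCR?\nAnswer:")[0]["answer"])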
# ---------------------------
# LightRAG: Retrieval-Augmented Generation (Dummy)
# ---------------------------
class LightRAG:
    def __init__(self, config: dict):
        self.config = config
        # Load generation and embedding models immediately (or lazy load as needed)
        self.gen_model = HFLLM(config["generation_model_id"])
        self._embedding_model = HFEmbedding(config["embedding_model_id"])

    def search(self, query: str, top_n: int = 5) -> list:
        # Dummy retrieval: in practice, integrate with a vector store
        from langchain_core.documents import Document  # ensure langchain_core is installed
        dummy_doc = Document(
            page_content="Dummy context for query: " + query,
            metadata={"type": "text"},
        )
        return [dummy_doc]

    def generate(self, query, context=None):
        if context is None:
            context = []
        # Build prompt by concatenating retrieved context with the query.
        prompt = ""
        for doc in context:
            prompt += doc.page_content + "\n"
        prompt += "\nQuestion: " + query + "\nAnswer:"
        results = self.gen_model.generate(prompt)
        answer = results[0]["answer"]
        return answer, prompt
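

# Usage sketch for LightRAG (illustrative only, not executed; uses the light_rag
# instance created just below):
#   docs = light_rag.search("What models are available in Watsonx?")
#   answer, prompt = light_rag.generate("What models are available in Watsonx?", docs)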
# Global configuration for LightRAG
rag_config = {
    "embedding_model_id": "ibm-granite/granite-embedding-125m-english",
    "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
}
light_rag = LightRAG(rag_config)
# ---------------------------
# Granite Vision functions (for image and video)
# ---------------------------
# Set the Granite Vision model ID (adjust version as needed)
GRANITE_MODEL_ID = "ibm-granite/granite-vision-3.2-2b"

granite_processor = None
granite_model = None


def load_granite_model():
    """Lazy load the Granite Vision processor and model."""
    global granite_processor, granite_model
    if granite_processor is None or granite_model is None:
        granite_processor = AutoProcessor.from_pretrained(GRANITE_MODEL_ID)
        # Move the model to the detected device explicitly; combining
        # device_map="auto" with .to(device) can conflict once accelerate
        # has already dispatched the weights.
        granite_model = AutoModelForVision2Seq.from_pretrained(GRANITE_MODEL_ID).to(device)
    return granite_processor, granite_model
def create_single_turn(image, text):
    """
    Creates a single-turn conversation message.
    If an image is provided, it is added along with the text.
    """
    if image is None:
        return {"role": "user", "content": [{"type": "text", "text": text}]}
    else:
        return {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text}]}
def generate_granite(image, prompt_text, max_new_tokens=1024, temperature=0.7, top_p=0.85, top_k=50, repetition_penalty=1.05):
    """
    Generates a response from the Granite Vision model given an image and prompt.
    """
    processor, model = load_granite_model()
    conversation = [create_single_turn(image, prompt_text)]
    inputs = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
    ).to(device)
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    decoded = processor.decode(output[0], skip_special_tokens=True)
    parts = decoded.strip().split("<|assistant|>")
    return parts[-1].strip()
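

# Usage sketch for generate_granite (illustrative only, not executed;
# "photo.png" is a hypothetical local file):
#   img = Image.open("photo.png").convert("RGB")
#   caption = generate_granite(img, "Describe this image")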
def generate_video_infer(video_path, prompt_text, max_new_tokens=1024, temperature=0.7, top_p=0.85, top_k=50, repetition_penalty=1.05):
    """
    Processes a video file by downsampling frames and sending them along with a prompt
    to the Granite Vision model.
    """
    frames = downsample_video(video_path)
    conversation_content = []
    for img, ts in frames:
        conversation_content.append({"type": "text", "text": f"Frame at {ts} sec:"})
        conversation_content.append({"type": "image", "image": img})
    conversation_content.append({"type": "text", "text": prompt_text})
    conversation = [{"role": "user", "content": conversation_content}]
    processor, model = load_granite_model()
    inputs = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
    ).to(device)
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    decoded = processor.decode(output[0], skip_special_tokens=True)
    parts = decoded.strip().split("<|assistant|>")
    return parts[-1].strip()
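

# Usage sketch for generate_video_infer (illustrative only, not executed;
# "clip.mp4" is a hypothetical local file):
#   summary = generate_video_infer("clip.mp4", "Summarize the event in the video")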
# ---------------------------
# Unified generation function for ChatInterface
# ---------------------------
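# With a multimodal ChatInterface, Gradio passes the textbox payload to `generate`
# as a dict. Illustrative shape (file entries are local paths to the uploads;
# the path below is hypothetical):
#   {"text": "@granite Describe the image", "files": ["/tmp/gradio/sample_image.png"]}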
def generate(input_dict: dict, chat_history: list[dict],
             max_new_tokens: int, temperature: float,
             top_p: float, top_k: int, repetition_penalty: float):
    """
    Chat function that inspects the input text for special commands and routes:
    - @rag: Uses the RAG pipeline.
    - @granite: Uses Granite Vision for image understanding.
    - @video-infer: Uses Granite Vision for video processing.
    """
    text = input_dict["text"]
    files = input_dict.get("files", [])
    lower_text = text.strip().lower()
    # Optionally yield a progress message
    yield "Processing your request..."
    time.sleep(1)  # simulate processing delay
    if lower_text.startswith("@rag"):
        query = text[len("@rag"):].strip()
        logging.info(f"@rag command: {query}")
        context = light_rag.search(query)
        answer, _ = light_rag.generate(query, context)
        yield answer
    elif lower_text.startswith("@granite"):
        prompt_text = text[len("@granite"):].strip()
        logging.info(f"@granite command: {prompt_text}")
        if files:
            # gr.MultimodalTextbox provides uploaded files as local paths,
            # so load the first one as a PIL image before passing it on.
            image = Image.open(files[0]).convert("RGB")
            answer = generate_granite(image, prompt_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
            yield answer
        else:
            yield "No image provided for @granite command."
    elif lower_text.startswith("@video-infer"):
        prompt_text = text[len("@video-infer"):].strip()
        logging.info(f"@video-infer command: {prompt_text}")
        if files:
            # Expecting a video file (the uploaded file path)
            video_path = files[0]
            answer = generate_video_infer(video_path, prompt_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
            yield answer
        else:
            yield "No video provided for @video-infer command."
    else:
        # Default behavior: use RAG pipeline for text query.
        query = text.strip()
        logging.info(f"Default text query: {query}")
        context = light_rag.search(query)
        answer, _ = light_rag.generate(query, context)
        yield answer
# ---------------------------
# Gradio ChatInterface using MultimodalTextbox
# ---------------------------
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.1, value=0.85),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.05),
    ],
    examples=[
        # Examples show how to use the command prefixes.
        [{"text": "@rag What models are available in Watsonx?"}],
        [{"text": "@granite Describe the image", "files": [str(Path("examples") / "sample_image.png")]}],
        [{"text": "@video-infer Summarize the event in the video", "files": [str(Path("examples") / "sample_video.mp4")]}],
    ],
    cache_examples=False,
    type="messages",
    description=(
        "# **Multimodal OCR Granite Vision**\n\n"
        "Enter a command in the text input (with optional file uploads) using one of the following prefixes:\n\n"
        "- **@rag**: For retrieval-augmented generation (e.g. PDFs, documents).\n"
        "- **@granite**: For image understanding using Granite Vision.\n"
        "- **@video-infer**: For video understanding (video is downsampled into frames).\n\n"
        "For example:\n```\n@rag What is the revenue trend?\n```\n```\n@granite Describe this image\n```\n```\n@video-infer Summarize the event in this video\n```"
    ),
    fill_height=True,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video", "pdf"],
        file_count="multiple",
        placeholder="@rag, @granite, or @video-infer followed by your prompt",
    ),
    stop_btn="Stop Generation",
    multimodal=True,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()