# Page header captured with this file: author minar09,
# commit d27c873 ("Update app.py", verified), 5.83 kB.
import os
import gradio as gr
import fitz # PyMuPDF
import shutil
import json
import torch
from PIL import Image
import re
# Import multimodal and Qwen2-VL models and processor from your dependencies.
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# --- Model Initialization ---
def initialize_models():
    """
    Load the ColPali RAG model and the Qwen2-VL-2B-Instruct model + processor.

    Returns:
        tuple: ``(multimodal_rag, qwen_model, qwen_processor)``, in that order.
    """
    qwen_checkpoint = "Qwen/Qwen2-VL-2B-Instruct"

    rag = RAGMultiModalModel.from_pretrained("vidore/colpali")

    # Full float32 weights; inference elsewhere in this file is pinned to CPU.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        qwen_checkpoint,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    processor = AutoProcessor.from_pretrained(qwen_checkpoint, trust_remote_code=True)

    return rag, model, processor


# Loaded once at import time; qwen_model / qwen_processor are used by the OCR
# and product-parsing functions.
# NOTE(review): multimodal_rag appears unused in this file, yet the ColPali
# model is still downloaded and loaded here — consider dropping it.
multimodal_rag, qwen_model, qwen_processor = initialize_models()
# --- OCR Function ---
def perform_ocr(image: Image.Image, max_new_tokens: int = 2000) -> str:
    """
    Extract text from an image using the Qwen2-VL model.

    Args:
        image: The image (e.g. a rendered PDF page) to transcribe.
        max_new_tokens: Upper bound on generated tokens; defaults to 2000,
            the value previously hard-coded here.

    Returns:
        The decoded model answer for the single request, with the prompt
        tokens removed and special tokens skipped.
    """
    query = "Extract text from the image in its original language."
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query},
            ],
        }
    ]
    prompt_text = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    model_inputs = qwen_processor(
        text=[prompt_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cpu")  # Use CPU for inference

    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Each generated row starts with its prompt tokens; slice them off.
    # Use plain loop variables so nothing is assigned back into model_inputs.
    trimmed_ids = [
        output_ids[len(prompt_ids):]
        for prompt_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return qwen_processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
# --- Product Parsing Function ---
def parse_product_info(text: str, max_new_tokens: int = 512) -> dict:
    """
    Parse combined OCR text into structured product information using Qwen2-VL.

    Args:
        text: Concatenated OCR output for the whole document.
        max_new_tokens: Upper bound on generated tokens; defaults to 512,
            the value previously hard-coded here.

    Returns:
        The first ``{...}`` span of the model answer decoded as JSON. Returns
        ``{}`` when ``text`` is empty/blank, when the answer contains no
        braces, or when that span is not valid JSON.
    """
    # No OCR text (e.g. every page failed): the model could only invent data.
    if not text or not text.strip():
        return {}

    prompt = f"""Extract product specifications from the following text. If no product information is found, return an empty JSON object with keys.
Text:
{text}
Return JSON format exactly as:
{{
"name": "product name",
"description": "product description",
"price": numeric_price,
"attributes": {{"key": "value"}}
}}"""
    messages = [{"role": "user", "content": prompt}]
    prompt_text = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Text-only message: kept for parity with the OCR path.
    image_inputs, video_inputs = process_vision_info(messages)
    model_inputs = qwen_processor(
        text=[prompt_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cpu")

    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Drop the prompt prefix from each row without mutating model_inputs.
    trimmed_ids = [
        output_ids[len(prompt_ids):]
        for prompt_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    parsed_result = qwen_processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # The model may wrap the JSON in prose or code fences; keep the outermost braces.
    json_start = parsed_result.find("{")
    json_end = parsed_result.rfind("}") + 1
    if json_start == -1 or json_end <= json_start:
        return {}
    try:
        return json.loads(parsed_result[json_start:json_end])
    except json.JSONDecodeError:
        return {}
# --- PDF Processing Function ---
def _open_pdf_document(pdf_file):
    """Open the uploaded PDF with PyMuPDF without writing to the upload.

    Accepts a filesystem path (``str`` / ``os.PathLike``), an object whose
    ``name`` is the path of an existing file, or an object exposing a readable
    binary ``file`` attribute.

    Raises:
        TypeError: if ``pdf_file`` is none of the above.
        RuntimeError: if PyMuPDF cannot open the document.
    """
    # Open in place instead of copying into a temp dir. When the upload's name
    # is an absolute path, os.path.join(temp_dir, name) returns that same path,
    # so the copy step opened the source for writing and truncated it.
    if isinstance(pdf_file, (str, os.PathLike)):
        open_kwargs = {"filename": os.fspath(pdf_file)}
    elif isinstance(getattr(pdf_file, "name", None), str) and os.path.isfile(pdf_file.name):
        open_kwargs = {"filename": pdf_file.name}
    elif hasattr(pdf_file, "file"):
        stream = pdf_file.file
        if hasattr(stream, "seek"):
            stream.seek(0)
        open_kwargs = {"stream": stream.read(), "filetype": "pdf"}
    else:
        raise TypeError("Invalid file input type.")

    try:
        return fitz.open(**open_kwargs)
    except Exception as e:
        raise RuntimeError(f"Cannot open PDF file: {e}") from e


def process_pdf(pdf_file) -> dict:
    """
    Process a PDF file: render each page to an image, OCR every page, then
    parse the combined text into structured product information.

    Args:
        pdf_file: The uploaded PDF. This can be a path string, an object with a
            ``name`` file path, or an object with a binary ``file`` attribute.
            NOTE(review): which of these Gradio passes depends on the
            installed version — confirm.

    Returns:
        The dict returned by ``parse_product_info`` for the joined page text.

    Raises:
        TypeError: if ``pdf_file`` is not a supported input.
        RuntimeError: if the PDF cannot be opened.
    """
    page_texts = []
    # The context manager closes the document before the (slow) parsing step.
    with _open_pdf_document(pdf_file) as doc:
        for page in doc:
            # Best effort: a failing page is reported and skipped.
            try:
                # Render page as image; adjust dpi as needed for quality/speed balance
                pix = page.get_pixmap(dpi=150)
                img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                page_texts.append(perform_ocr(img) + "\n")
            except Exception as e:
                print(f"Warning: Failed to process page {page.number + 1}: {e}")

    return parse_product_info("".join(page_texts))
# --- Gradio Interface ---
# Single-page UI: upload a PDF, click the button, see the extracted JSON.
with gr.Blocks() as interface:
    gr.Markdown("<h1 style='text-align: center;'>PDF Product Info Extractor</h1>")
    with gr.Row():
        # Restrict the file picker to PDFs; process_pdf only handles PDF input.
        pdf_input = gr.File(label="Upload PDF File", file_count="single", file_types=[".pdf"])
        extract_btn = gr.Button("Extract Product Info")
    output_box = gr.JSON(label="Extracted Product Info")
    extract_btn.click(process_pdf, inputs=pdf_input, outputs=output_box)

# Launch only when run as a script, so importing this module does not start
# (and block on) the server.
if __name__ == "__main__":
    interface.launch(debug=True)