import os
import gradio as gr
import fitz  # PyMuPDF
import shutil
import json
import torch
from PIL import Image
import re

# Import multimodal and Qwen2-VL models and processor from your dependencies.
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# --- Model Initialization ---
def initialize_models():
    """
    Loads and returns the RAG multimodal and Qwen2-VL models along with the processor.
    """
    multimodal_rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.float32
    )
    qwen_processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True
    )
    return multimodal_rag, qwen_model, qwen_processor

multimodal_rag, qwen_model, qwen_processor = initialize_models()
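
# Optional tweak (an assumption, not part of the original setup): if a GPU is
# available, moving the Qwen2-VL model onto it speeds up generation considerably.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# qwen_model = qwen_model.to(device)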

# --- OCR Function ---
def perform_ocr(image: Image.Image) -> str:
    """
    Extracts text from an image using the Qwen2-VL model.
    """
    query = "Extract text from the image in its original language."
    user_input = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query}
            ]
        }
    ]
    input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(user_input)
    model_inputs = qwen_processor(
        text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")  # Use CPU for inference
    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000)
    # Remove the prompt tokens from the generated output
    trimmed_ids = [
        output[len(input_ids):]
        for input_ids, output in zip(model_inputs.input_ids, generated_ids)
    ]
    ocr_result = qwen_processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return ocr_result
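
# Example usage (a hypothetical standalone check; "sample_page.png" is assumed):
#   print(perform_ocr(Image.open("sample_page.png").convert("RGB")))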

# --- Product Parsing Function ---
def parse_product_info(text: str) -> dict:
    """
    Parses the combined OCR text into structured product information using Qwen2-VL.
    """
    prompt = f"""Extract product specifications from the following text. If no product information is found, return a JSON object with the same keys and empty values.
Text:
{text}
Return JSON format exactly as:
{{
    "name": "product name",
    "description": "product description",
    "price": numeric_price,
    "attributes": {{"key": "value"}}
}}"""
    user_input = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(user_input)
    model_inputs = qwen_processor(
        text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")
    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=512)
    trimmed_ids = [
        output[len(input_ids):]
        for input_ids, output in zip(model_inputs.input_ids, generated_ids)
    ]
    parsed_result = qwen_processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # Extract the first JSON object from the model's reply; fall back to {} on failure.
    try:
        json_start = parsed_result.find('{')
        json_end = parsed_result.rfind('}') + 1
        data = json.loads(parsed_result[json_start:json_end])
    except Exception:
        data = {}
    return data
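
# Illustrative output (model-dependent; the values here are made up):
#   {"name": "Acme Widget", "description": "Stainless-steel widget",
#    "price": 19.99, "attributes": {"material": "steel"}}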

# --- PDF Processing Function ---
def process_pdf(pdf_file) -> dict:
    """
    Processes a PDF file by converting each page to an image,
    performing OCR on each page, and then parsing the combined
    text into structured product information.
    """
    # Create a temporary directory for the PDF file
    temp_dir = "./temp_pdf/"
    os.makedirs(temp_dir, exist_ok=True)
    # pdf_file may be a tempfile-like wrapper or a plain path string,
    # depending on the Gradio version.
    source_name = pdf_file if isinstance(pdf_file, str) else getattr(pdf_file, "name", None)
    if source_name is None:
        raise TypeError("Invalid file input type.")
    pdf_path = os.path.join(temp_dir, os.path.basename(source_name))
    if hasattr(pdf_file, "file"):
        # File-like wrapper: stream its contents into the temp location
        with open(pdf_path, "wb") as f:
            shutil.copyfileobj(pdf_file.file, f)
    else:
        # Path string (or tempfile object): copy the file on disk
        shutil.copy(source_name, pdf_path)
    # Open the PDF file using PyMuPDF
    try:
        doc = fitz.open(pdf_path)
    except Exception as e:
        raise RuntimeError(f"Cannot open PDF file: {e}")
    combined_text = ""
    # Iterate over each page and extract text via OCR
    for page in doc:
        try:
            # Render the page as an image; adjust dpi for the quality/speed balance
            pix = page.get_pixmap(dpi=150)
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            page_text = perform_ocr(img)
            combined_text += page_text + "\n"
        except Exception as e:
            print(f"Warning: Failed to process page {page.number + 1}: {e}")
    doc.close()
    # Parse the combined OCR text into structured product info
    product_info = parse_product_info(combined_text)
    return product_info
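
# A possible refinement (an assumption, not in the original flow): remove the
# temporary PDF directory once processing is done, e.g.
#   shutil.rmtree(temp_dir, ignore_errors=True)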

# --- Gradio Interface ---
with gr.Blocks() as interface:
    gr.Markdown("<h1 style='text-align: center;'>PDF Product Info Extractor</h1>")
    with gr.Row():
        pdf_input = gr.File(label="Upload PDF File", file_count="single")
        extract_btn = gr.Button("Extract Product Info")
    output_box = gr.JSON(label="Extracted Product Info")
    extract_btn.click(process_pdf, inputs=pdf_input, outputs=output_box)
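
    # Optional (an assumption, not in the original code): OCR on a long PDF can
    # take minutes on CPU, so enabling Gradio's request queue before launch may help:
    # interface.queue()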

interface.launch(debug=True)