import os
import gradio as gr
import fitz  # PyMuPDF
import shutil
import json
import torch
from PIL import Image

# Import the multimodal RAG model and the Qwen2-VL model/processor.
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
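
# Pipeline: PDF -> page images (PyMuPDF) -> per-page OCR (Qwen2-VL)
# -> combined text -> structured product JSON (Qwen2-VL).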

# --- Model Initialization ---
def initialize_models():
    """
    Loads and returns the RAG multimodal and Qwen2-VL models along with the processor.
    """
    multimodal_rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.float32  # full precision; half precision is unreliable on CPU
    )
    qwen_processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True
    )
    return multimodal_rag, qwen_model, qwen_processor

# Load the models once at import time so every request reuses them.
multimodal_rag, qwen_model, qwen_processor = initialize_models()
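
# Note: multimodal_rag (ColPali via byaldi) is loaded but never queried below;
# presumably it is reserved for a future retrieval step over indexed PDFs.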

# --- OCR Function ---
def perform_ocr(image: Image.Image) -> str:
    """
    Extracts text from an image using the Qwen2-VL model.
    """
    query = "Extract text from the image in its original language."
    user_input = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query}
            ]
        }
    ]
    input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(user_input)
    model_inputs = qwen_processor(
        text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")  # run inference on CPU
    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000)
    # Strip the prompt tokens so only the newly generated text is decoded
    trimmed_ids = [
        output[len(input_ids):]
        for input_ids, output in zip(model_inputs.input_ids, generated_ids)
    ]
    ocr_result = qwen_processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return ocr_result
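
# Note: max_new_tokens=2000 caps the transcription length, so very dense
# pages may be truncated at this limit.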

# --- Product Parsing Function ---
def parse_product_info(text: str) -> dict:
    """
    Parses the combined OCR text into structured product information using Qwen2-VL.
    """
    prompt = f"""Extract product specifications from the following text. If no product information is found, return a JSON object with the same keys and empty values.
Text:
{text}
Return JSON in exactly this format:
{{
    "name": "product name",
    "description": "product description",
    "price": numeric_price,
    "attributes": {{"key": "value"}}
}}"""
    user_input = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(user_input)
    model_inputs = qwen_processor(
        text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")
    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=512)
    trimmed_ids = [
        output[len(input_ids):]
        for input_ids, output in zip(model_inputs.input_ids, generated_ids)
    ]
    parsed_result = qwen_processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # Slice from the first '{' to the last '}' to tolerate prose around the JSON
    try:
        json_start = parsed_result.find('{')
        json_end = parsed_result.rfind('}') + 1
        data = json.loads(parsed_result[json_start:json_end])
    except Exception:
        # Fall back to an empty dict when the reply contains no valid JSON
        data = {}
    return data
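
# Expected shape of the result (hypothetical values, matching the prompt schema):
# {"name": "Acme Widget", "description": "A compact widget.", "price": 19.99,
#  "attributes": {"color": "red"}}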

# --- PDF Processing Function ---
def process_pdf(pdf_file) -> dict:
    """
    Processes a PDF file by converting each page to an image,
    performing OCR on each page, and then parsing the combined
    text into structured product information.
    """
    # Copy the upload into a temporary working directory; basename guards
    # against absolute temp paths overwriting the original upload
    temp_dir = "./temp_pdf/"
    os.makedirs(temp_dir, exist_ok=True)
    pdf_path = os.path.join(temp_dir, os.path.basename(pdf_file.name))
    if hasattr(pdf_file, "file"):
        with open(pdf_path, "wb") as f:
            shutil.copyfileobj(pdf_file.file, f)
    elif hasattr(pdf_file, "name"):
        # Gradio usually passes a temp-file wrapper whose .name is a path
        shutil.copy(pdf_file.name, pdf_path)
    else:
        raise TypeError("Invalid file input type.")
    # Open the PDF file using PyMuPDF
    try:
        doc = fitz.open(pdf_path)
    except Exception as e:
        raise RuntimeError(f"Cannot open PDF file: {e}")
    combined_text = ""
    # Iterate over each page and extract text via OCR
    for page in doc:
        try:
            # Render the page as an image; adjust dpi for the quality/speed balance
            pix = page.get_pixmap(dpi=150)
            img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            page_text = perform_ocr(img)
            combined_text += page_text + "\n"
        except Exception as e:
            print(f"Warning: Failed to process page {page.number + 1}: {e}")
    doc.close()
    # Parse the combined OCR text into structured product info
    product_info = parse_product_info(combined_text)
    return product_info

# --- Gradio Interface ---
with gr.Blocks() as interface:
    gr.Markdown("<h1 style='text-align: center;'>PDF Product Info Extractor</h1>")
    with gr.Row():
        pdf_input = gr.File(label="Upload PDF File", file_count="single")
        extract_btn = gr.Button("Extract Product Info")
    output_box = gr.JSON(label="Extracted Product Info")
    extract_btn.click(process_pdf, inputs=pdf_input, outputs=output_box)

interface.launch(debug=True)