import os
import gradio as gr
import fitz  # PyMuPDF
import shutil
import json
import torch
from PIL import Image

# Import multimodal and Qwen2-VL models and processor from your dependencies.
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
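# The third-party packages above can typically be installed with (assumed,
# unpinned package names):
#   pip install gradio PyMuPDF torch pillow byaldi transformers qwen-vl-utils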

# --- Model Initialization ---

def initialize_models():
    """
    Loads and returns the RAG multimodal and Qwen2-VL models along with the processor.
    """
    multimodal_rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.float32
    )
    qwen_processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True
    )
    return multimodal_rag, qwen_model, qwen_processor

# NOTE: multimodal_rag is initialized for completeness but is not used in the
# extraction pipeline below.
multimodal_rag, qwen_model, qwen_processor = initialize_models()
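
# Optional: a minimal sketch of GPU placement, assuming a CUDA device is
# available. This demo keeps everything on CPU; if enabled, the .to("cpu")
# calls below would also need to become .to(device).
# device = "cuda" if torch.cuda.is_available() else "cpu"
# qwen_model = qwen_model.to(device)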

# --- OCR Function ---
def perform_ocr(image: Image.Image) -> str:
    """
    Extracts text from an image using the Qwen2-VL model.
    """
    query = "Extract text from the image in its original language."
    user_input = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query}
            ]
        }
    ]
    input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(user_input)
    model_inputs = qwen_processor(
        text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")  # Use CPU for inference
    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000)
        # Remove the prompt tokens from the generated output
        trimmed_ids = [
            output[len(input_ids):]
            for input_ids, output in zip(model_inputs.input_ids, generated_ids)
        ]
        ocr_result = qwen_processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return ocr_result
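
# Example usage of perform_ocr (assumes a page image "sample_page.png" exists):
# page_image = Image.open("sample_page.png").convert("RGB")
# print(perform_ocr(page_image))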

# --- Product Parsing Function ---
def parse_product_info(text: str) -> dict:
    """
    Parses the combined OCR text into structured product information using Qwen2-VL.
    """
    prompt = f"""Extract product specifications from the following text. If no product information is found, return an empty JSON object with keys.

Text:
{text}

Return JSON format exactly as:
{{
    "name": "product name",
    "description": "product description",
    "price": numeric_price,
    "attributes": {{"key": "value"}}
}}"""
    user_input = [{"role": "user", "content": prompt}]
    input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(user_input)
    model_inputs = qwen_processor(
        text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")
    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=512)
        trimmed_ids = [
            output[len(input_ids):]
            for input_ids, output in zip(model_inputs.input_ids, generated_ids)
        ]
        parsed_result = qwen_processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    try:
        json_start = parsed_result.find('{')
        json_end = parsed_result.rfind('}') + 1
        if json_start == -1 or json_end == 0:
            raise ValueError("No JSON object found in model output.")
        data = json.loads(parsed_result[json_start:json_end])
    except Exception:
        # Fall back to an empty dict if the model output is not valid JSON
        data = {}
    return data
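
# Illustrative call (actual output depends on the model's response):
# parse_product_info("SuperWidget 3000. Heavy-duty steel widget. Price: $49.99")
# might return:
# {"name": "SuperWidget 3000", "description": "Heavy-duty steel widget",
#  "price": 49.99, "attributes": {}}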

# --- PDF Processing Function ---
def process_pdf(pdf_file) -> dict:
    """
    Processes a PDF file by converting each page to an image,
    performing OCR on each page, and then parsing the combined
    text into structured product information.
    """
    # Create a temporary directory for the PDF file
    temp_dir = "./temp_pdf/"
    os.makedirs(temp_dir, exist_ok=True)
    # Resolve the upload's filename and copy it into the temp directory.
    src_name = getattr(pdf_file, "name", pdf_file if isinstance(pdf_file, str) else "upload.pdf")
    pdf_path = os.path.join(temp_dir, os.path.basename(src_name))
    if hasattr(pdf_file, "file"):
        # File-like wrapper exposing a readable binary stream
        with open(pdf_path, "wb") as f:
            shutil.copyfileobj(pdf_file.file, f)
    elif hasattr(pdf_file, "name"):
        # Gradio typically passes a tempfile-like object whose .name is a path
        shutil.copy(pdf_file.name, pdf_path)
    elif isinstance(pdf_file, str):
        # Plain path string (e.g., when calling process_pdf directly)
        shutil.copy(pdf_file, pdf_path)
    else:
        raise TypeError("Invalid file input type.")
    
    # Open the PDF file using PyMuPDF
    try:
        doc = fitz.open(pdf_path)
    except Exception as e:
        raise RuntimeError(f"Cannot open PDF file: {e}")
    
    combined_text = ""
    # Iterate over each page and extract text via OCR
    for page in doc:
        try:
            # Render page as image; adjust dpi as needed for quality/speed balance
            pix = page.get_pixmap(dpi=150)
            img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            page_text = perform_ocr(img)
            combined_text += page_text + "\n"
        except Exception as e:
            print(f"Warning: Failed to process page {page.number + 1}: {e}")
    
    doc.close()

    # Parse the combined OCR text into structured product info
    product_info = parse_product_info(combined_text)
    return product_info
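
# Example of calling process_pdf directly with a path string (hypothetical
# "catalog.pdf"; the Gradio UI below passes its own file object instead):
# info = process_pdf("catalog.pdf")
# print(json.dumps(info, indent=2, ensure_ascii=False))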

# --- Gradio Interface ---
with gr.Blocks() as interface:
    gr.Markdown("<h1 style='text-align: center;'>PDF Product Info Extractor</h1>")
    with gr.Row():
        pdf_input = gr.File(label="Upload PDF File", file_count="single")
        extract_btn = gr.Button("Extract Product Info")
    output_box = gr.JSON(label="Extracted Product Info")
    extract_btn.click(process_pdf, inputs=pdf_input, outputs=output_box)

interface.launch(debug=True)