File size: 4,077 Bytes
2c8b539
 
5156ae8
2c8b539
 
a0e2927
2c8b539
ef4c75d
2c8b539
5156ae8
ef4c75d
2c8b539
3cb5e70
2c8b539
 
5156ae8
 
 
 
 
 
 
 
 
 
 
 
 
 
ef4c75d
 
 
5156ae8
 
 
ef4c75d
 
 
5156ae8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef4c75d
5156ae8
 
 
 
ef4c75d
 
 
 
 
5156ae8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef4c75d
 
5156ae8
 
 
 
 
 
ef4c75d
5156ae8
 
 
2c8b539
 
ef4c75d
5156ae8
2c8b539
ef4c75d
2c8b539
5156ae8
ef4c75d
2c8b539
 
 
ef4c75d
5156ae8
 
ef4c75d
5156ae8
ef4c75d
a0e2927
ef4c75d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import tqdm
from PIL import Image
import hashlib
import torch
import fitz
import gradio as gr
import os
from transformers import AutoModel, AutoTokenizer
import numpy as np
import json
import spaces

# Root directory for the per-PDF caches written by add_pdf_gradio and read by
# retrieve_gradio (one sub-directory per uploaded PDF, named by its MD5).
# Defaults to ./kb_cache, which is resolved against the current working
# directory; set KB_CACHE_DIR to place it elsewhere (e.g. persistent storage).
cache_dir = os.environ.get('KB_CACHE_DIR', 'kb_cache')
os.makedirs(cache_dir, exist_ok=True)

def get_image_md5(img: Image.Image):
    """Return the hexadecimal MD5 digest of an image's raw pixel buffer.

    Only ``img.tobytes()`` is hashed: the image's size and mode are not part
    of the digest, so two images whose pixel bytes are identical but whose
    dimensions differ produce the same value.
    """
    pixel_bytes = img.tobytes()
    return hashlib.md5(pixel_bytes).hexdigest()

def calculate_md5_from_binary(binary_data):
    hash_md5 = hashlib.md5()
    hash_md5.update(binary_data)
    return hash_md5.hexdigest()

@spaces.GPU(duration=100)
def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):
    """Embed every page of an uploaded PDF and cache the results on disk.

    The PDF bytes are hashed with MD5 to name a per-document directory under
    ``cache_dir``. Into that directory this writes:

    * ``src.pdf``   -- the original upload,
    * ``<md5>.png`` -- one rendered image per page (200 dpi), named by the
      MD5 of the page image's pixel bytes,
    * ``reps.npy``  -- the page embeddings, one row per page, in page order,
    * ``md5s.txt``  -- the page image MD5s, one per line, in page order
      (line i corresponds to row i of ``reps.npy``).

    Args:
        pdf_file_binary: Raw PDF bytes from the ``gr.File(type="binary")``
            upload, or ``None`` when nothing was uploaded.
        progress: Gradio progress tracker used to report per-page progress.

    Returns:
        A human-readable status string for the UI.
    """
    if pdf_file_binary is None:
        return "No PDF file uploaded."

    global model, tokenizer
    model.eval()

    knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)

    this_cache_dir = os.path.join(cache_dir, knowledge_base_name)
    os.makedirs(this_cache_dir, exist_ok=True)

    with open(os.path.join(this_cache_dir, "src.pdf"), 'wb') as file:
        file.write(pdf_file_binary)

    dpi = 200
    reps_list = []
    image_md5s = []

    # Open from the in-memory bytes; "pdf" tells PyMuPDF the stream's type.
    doc = fitz.open("pdf", pdf_file_binary)
    try:
        if doc.page_count == 0:
            # Nothing to embed; also avoids writing an empty reps.npy that
            # retrieval could not rank against.
            return "The uploaded PDF has no pages."

        with torch.no_grad():
            for page in progress.tqdm(doc):
                pix = page.get_pixmap(dpi=dpi)
                image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                image_md5 = get_image_md5(image)

                reps = model(text=[''], image=[image], tokenizer=tokenizer).reps
                # Cast to float32 before .numpy(): NumPy has no bfloat16, so
                # .numpy() would fail if the model produced bfloat16 output.
                reps_list.append(reps.squeeze(0).float().cpu().numpy())

                # Save each page image as soon as it is rendered rather than
                # keeping every page of the document in memory at once.
                image.save(os.path.join(this_cache_dir, f"{image_md5}.png"))
                image_md5s.append(image_md5)
    finally:
        # Release the MuPDF document even if rendering or embedding fails.
        doc.close()

    np.save(os.path.join(this_cache_dir, "reps.npy"), np.stack(reps_list))

    with open(os.path.join(this_cache_dir, "md5s.txt"), 'w') as f:
        f.writelines(f"{md5}\n" for md5 in image_md5s)

    return "PDF processed successfully!"

def retrieve_gradio(pdf_file_binary, query: str, topk: int):
    """Return the cached page images of a PDF that best match ``query``.

    The PDF is located in the cache by the MD5 of its bytes. The cache must
    already hold ``md5s.txt`` and ``reps.npy`` written by add_pdf_gradio.
    The query is embedded with an instruction prefix, and pages are ranked by
    the dot product between the query embedding and each page embedding.

    Args:
        pdf_file_binary: Raw PDF bytes from the upload widget, or ``None``.
        query: Free-text question.
        topk: Number of pages to return. Gradio's Number may deliver this as
            a float; it is converted to int and clamped to the page count.

    Returns:
        A list of PIL images ordered from most to least similar. Returns
        ``None`` when the PDF has not been processed yet, and an empty list
        when the cached document has no pages.

    Raises:
        gr.Error: If no PDF has been uploaded.
    """
    global model, tokenizer

    if pdf_file_binary is None:
        # The output component is a Gallery, which cannot display a plain
        # string, so report the problem as a Gradio error instead.
        raise gr.Error("No PDF file uploaded.")

    model.eval()

    knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
    target_cache_dir = os.path.join(cache_dir, knowledge_base_name)
    md5s_path = os.path.join(target_cache_dir, "md5s.txt")
    reps_path = os.path.join(target_cache_dir, "reps.npy")

    # add_pdf_gradio creates the directory before embedding any page, so the
    # directory alone does not mean processing finished; require both outputs.
    if not (os.path.exists(md5s_path) and os.path.exists(reps_path)):
        return None

    with open(md5s_path, 'r') as f:
        md5s = [line.rstrip('\n') for line in f]
    if not md5s:
        return []

    # One row per page, in the same order as md5s; compare in float32.
    doc_reps = torch.from_numpy(np.load(reps_path)).float()

    query_with_instruction = "Represent this query for retrieving relevant document: " + query
    with torch.no_grad():
        query_rep = model(text=[query_with_instruction], image=[None], tokenizer=tokenizer).reps.squeeze(0).cpu()

    # .float() keeps dtypes aligned if the model returns half-precision output.
    similarities = torch.matmul(query_rep.float(), doc_reps.T)

    # torch.topk needs an int k no larger than the number of pages.
    k = max(1, min(int(topk), len(md5s)))
    _, topk_doc_ids = torch.topk(similarities, k=k)

    images_topk = []
    for idx in topk_doc_ids.tolist():
        image_path = os.path.join(target_cache_dir, f"{md5s[idx]}.png")
        # copy() loads the pixels, so the file handle can be closed right away.
        with Image.open(image_path) as img:
            images_topk.append(img.copy())

    return images_topk


# NOTE(review): no assignment to `model` or `tokenizer` is visible in this
# file (AutoModel/AutoTokenizer are imported but never called), yet both
# handlers below read them as globals. Confirm they are loaded before the UI
# is launched; otherwise the first button click raises NameError.
with gr.Blocks() as app:
    gr.Markdown("# MiniCPMV-RAG-PDFQA")

    with gr.Row():
        file_input = gr.File(type="binary", label="Upload PDF", file_types=[".pdf"])
        process_button = gr.Button("Process PDF")

    # Explicit, rendered component that shows the status string returned by
    # add_pdf_gradio.
    process_status = gr.Textbox(label="Status", interactive=False)
    process_button.click(add_pdf_gradio, inputs=[file_input], outputs=process_status)

    with gr.Row():
        query_input = gr.Text(label="Your Question")
        # precision=0 makes Gradio deliver an int; with the default precision
        # the value arrives as a float, which torch.topk rejects as k.
        topk_input = gr.Number(
            value=5,
            minimum=1,
            maximum=10,
            step=1,
            precision=0,
            label="Number of pages to retrieve",
        )
        retrieve_button = gr.Button("Retrieve Pages")

    images_output = gr.Gallery(label="Retrieved Pages")

    retrieve_button.click(
        retrieve_gradio,
        inputs=[file_input, query_input, topk_input],
        outputs=images_output,
    )

if __name__ == "__main__":
    app.launch(share=True)