dhairyashah committed (verified)
Commit ef4c75d · 1 Parent(s): 3cb5e70

Update app.py

Files changed (1): app.py (+24 -74)

app.py CHANGED
@@ -1,20 +1,14 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-
 import tqdm
 from PIL import Image
 import hashlib
 import torch
 import fitz
-import threading
 import gradio as gr
-import spaces
 import os
-from transformers import AutoModel
-from transformers import AutoTokenizer
+from transformers import AutoModel, AutoTokenizer
 import numpy as np
 import json
+import spaces

 cache_dir = 'kb_cache'
 os.makedirs(cache_dir, exist_ok=True)
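
Note on the import shuffle: `import spaces` moves below the other imports but is still needed, because `add_pdf_gradio` keeps the ZeroGPU decorator. A minimal sketch of that pattern, purely illustrative: `embed_batch` is a hypothetical function that is not part of app.py, and the decorator only has an effect inside a Hugging Face ZeroGPU Space.

```python
import spaces
import torch

@spaces.GPU(duration=100)  # on a ZeroGPU Space, a GPU is attached for up to 100 s per call
def embed_batch(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: run the compute on CUDA, hand the result back on CPU.
    x = x.to("cuda")
    return (x @ x.T).cpu()
```
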
@@ -33,10 +27,15 @@ def calculate_md5_from_binary(binary_data):

 @spaces.GPU(duration=100)
 def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):
+    if pdf_file_binary is None:
+        return "No PDF file uploaded."
+
     global model, tokenizer
     model.eval()

-    this_cache_dir = os.path.join(cache_dir, 'temp_cache')
+    knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
+
+    this_cache_dir = os.path.join(cache_dir, knowledge_base_name)
     os.makedirs(this_cache_dir, exist_ok=True)

     with open(os.path.join(this_cache_dir, f"src.pdf"), 'wb') as file:
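
The cache directory is now keyed on `calculate_md5_from_binary(pdf_file_binary)` instead of a shared `temp_cache`, so each uploaded PDF gets its own knowledge base. The helper itself sits outside this diff (it only appears in the hunk header); a minimal sketch of what it presumably does, assuming it simply hashes the uploaded bytes and the hex digest becomes the directory name under `kb_cache`:

```python
import hashlib
import os

def calculate_md5_from_binary(binary_data: bytes) -> str:
    # Assumed implementation: hex MD5 digest of the raw PDF bytes.
    return hashlib.md5(binary_data).hexdigest()

pdf_bytes = b"%PDF-1.4 ..."  # placeholder for the uploaded file's contents
kb_name = calculate_md5_from_binary(pdf_bytes)
print(os.path.join("kb_cache", kb_name))  # same bytes always map to the same cache directory
```
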
@@ -73,13 +72,16 @@ def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):

     return "PDF processed successfully!"

-@spaces.GPU(duration=50)
-def retrieve_gradio(query: str, topk: int):
+def retrieve_gradio(pdf_file_binary, query: str, topk: int):
     global model, tokenizer

     model.eval()

-    target_cache_dir = os.path.join(cache_dir, 'temp_cache')
+    if pdf_file_binary is None:
+        return "No PDF file uploaded."
+
+    knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
+    target_cache_dir = os.path.join(cache_dir, knowledge_base_name)

     if not os.path.exists(target_cache_dir):
         return None
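
Because `retrieve_gradio` now receives the uploaded PDF and re-derives the same MD5 key, retrieval lands in the directory that `add_pdf_gradio` populated. A hedged usage sketch of calling it directly, outside the Gradio UI; it assumes the embedding model, tokenizer, and cached page representations are loaded as in the rest of app.py, and `my_doc.pdf` is a hypothetical path:

```python
# Hypothetical direct call after the PDF has been processed once.
with open("my_doc.pdf", "rb") as f:  # placeholder path
    pdf_bytes = f.read()

pages = retrieve_gradio(pdf_bytes, "What are the main findings?", 3)
if pages is None:
    print("No knowledge base for this PDF yet - run add_pdf_gradio first.")
else:
    for i, page in enumerate(pages):  # PIL.Image pages, most similar first
        page.save(f"retrieved_page_{i}.png")
```
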
@@ -95,87 +97,35 @@ def retrieve_gradio(query: str, topk: int):
     with torch.no_grad():
         query_rep = model(text=[query_with_instruction], image=[None], tokenizer=tokenizer).reps.squeeze(0).cpu()

+    query_md5 = hashlib.md5(query.encode()).hexdigest()
+
     doc_reps_cat = torch.stack([torch.Tensor(i) for i in doc_reps], dim=0)

     similarities = torch.matmul(query_rep, doc_reps_cat.T)

     topk_values, topk_doc_ids = torch.topk(similarities, k=topk)

-    topk_doc_ids_np = topk_doc_ids.cpu().numpy()
-
-    images_topk = [Image.open(os.path.join(target_cache_dir, f"{md5s[idx]}.png")) for idx in topk_doc_ids_np]
+    images_topk = [Image.open(os.path.join(target_cache_dir, f"{md5s[idx]}.png")) for idx in topk_doc_ids.cpu().numpy()]

     return images_topk

-device = 'cuda'
-
-print("emb model load begin...")
-model_path = 'RhapsodyAI/minicpm-visual-embedding-v0' # replace with your local model path
-tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
-model.eval()
-model.to(device)
-print("emb model load success!")
-
-print("gen model load begin...")
-gen_model_path = 'openbmb/MiniCPM-V-2_6'
-gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_path, trust_remote_code=True)
-gen_model = AutoModel.from_pretrained(gen_model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16)
-gen_model.eval()
-gen_model.to(device)
-print("gen model load success!")
-
-@spaces.GPU(duration=50)
-def answer_question(images, question):
-    global gen_model, gen_tokenizer
-    images_ = [Image.open(image[0]).convert('RGB') for image in images]
-    msgs = [{'role': 'user', 'content': [question, *images_]}]
-    answer = gen_model.chat(
-        image=None,
-        msgs=msgs,
-        tokenizer=gen_tokenizer
-    )
-    print(answer)
-    return answer

 with gr.Blocks() as app:
-    gr.Markdown("# MiniCPMV-RAG-PDFQA: Two Vision Language Models Enable End-to-End RAG")
-
-    gr.Markdown("""
-    - A Vision Language Model Dense Retriever ([minicpm-visual-embedding-v0](https://huggingface.co/RhapsodyAI/minicpm-visual-embedding-v0)) **directly reads** your PDFs **without need of OCR**, produce **multimodal dense representations** and build your personal library.
-
-    - **Ask a question**, it retrieves the most relevant pages, then [MiniCPM-V-2.6](https://huggingface.co/spaces/openbmb/MiniCPM-V-2_6) will answer your question based on pages recalled, with strong multi-image understanding capability.
-
-    - It helps you read a long **visually-intensive** or **text-oriented** PDF document and find the pages that answer your question.
+    gr.Markdown("# MiniCPMV-RAG-PDFQA")

-    - It helps you build a personal library and retrieve book pages from a large collection of books.
-
-    - It works like a human: read, store, retrieve, and answer with full vision.
-    """)
-
-    gr.Markdown("- Currently online demo support PDF document with less than 50 pages due to GPU time limit. Deploy on your own machine for longer PDFs and books.")
-
     with gr.Row():
-        file_input = gr.File(type="binary", label="Step 1: Upload PDF")
+        file_input = gr.File(type="binary", label="Upload PDF")
         process_button = gr.Button("Process PDF")
-        file_result = gr.Textbox(label="PDF Process Status")

-    process_button.click(add_pdf_gradio, inputs=[file_input], outputs=file_result)
+    process_button.click(add_pdf_gradio, inputs=[file_input], outputs="text")

     with gr.Row():
         query_input = gr.Text(label="Your Question")
-        topk_input = gr.Number(value=5, minimum=1, maximum=10, step=1, label="Number of Pages to Retrieve")
+        topk_input = gr.Number(value=5, minimum=1, maximum=10, step=1, label="Number of pages to retrieve")
         retrieve_button = gr.Button("Retrieve Pages")
-        images_output = gr.Gallery(label="Retrieved Pages")

-    retrieve_button.click(retrieve_gradio, inputs=[query_input, topk_input], outputs=images_output)
-
-    with gr.Row():
-        answer_button = gr.Button("Answer Question")
-        gen_model_response = gr.Textbox(label="MiniCPM-V-2.6's Answer")
-
-    answer_button.click(fn=answer_question, inputs=[images_output, query_input], outputs=gen_model_response)
+        images_output = gr.Gallery(label="Retrieved Pages")

-    gr.Markdown("By using this demo, you agree to share your use data with us for research purpose, to help improve user experience.")
+    retrieve_button.click(retrieve_gradio, inputs=[file_input, query_input, topk_input], outputs=images_output)

-app.launch()
+app.launch(share=True)
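
The page scoring itself is unchanged by this commit: the query embedding is compared against the stacked page embeddings with a dot product, and `torch.topk` picks the pages to show. A self-contained sketch of that step with random stand-in embeddings; the sizes below are illustrative and not taken from the model:

```python
import torch

torch.manual_seed(0)
num_pages, dim, topk = 12, 2304, 3              # illustrative sizes only

query_rep = torch.randn(dim)                    # stand-in for model(...).reps.squeeze(0).cpu()
doc_reps_cat = torch.randn(num_pages, dim)      # stand-in for the cached page embeddings

similarities = torch.matmul(query_rep, doc_reps_cat.T)        # one score per page
topk_values, topk_doc_ids = torch.topk(similarities, k=topk)

print(topk_doc_ids.tolist())  # indices of the best-matching pages, highest score first
```
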