bokesyo committed on
Commit
b3fffcd
1 Parent(s): cc886ab
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -90,16 +90,15 @@ class PDFVisualRetrieval:
90
  dpi = 100
91
  doc = fitz.open("pdf", pdf_file_binary)
92
 
93
- with spaces.GPU():
94
- for page in progress.tqdm(doc):
95
- with self.lock: # because we hope one 16G gpu only process one image at the same time
96
- pix = page.get_pixmap(dpi=dpi)
97
- image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
98
- image_md5 = get_image_md5(image)
99
- with torch.no_grad():
100
- reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
101
- self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
102
- self.images[knowledge_base_name][image_md5] = image
103
 
104
  return knowledge_base_name
105
 
@@ -137,6 +136,14 @@ if __name__ == "__main__":
137
 
138
  retriever = PDFVisualRetrieval(model=model, tokenizer=tokenizer)
139
 
 
 
 
 
 
 
 
 
140
  # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='what is the number of VQ of this kind of codec method?', topk=1)
141
  # # 2
142
  # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the training loss curve of this paper?', topk=1)
@@ -152,7 +159,7 @@ if __name__ == "__main__":
152
  file_result = gr.Text(label="Knowledge Base ID (remember this!)")
153
  process_button = gr.Button("Process PDF")
154
 
155
- process_button.click(retriever.add_pdf_gradio, inputs=[file_input], outputs=file_result)
156
 
157
  with gr.Row():
158
  kb_id_input = gr.Text(label="Your Knowledge Base ID")
@@ -163,7 +170,7 @@ if __name__ == "__main__":
163
  with gr.Row():
164
  images_output = gr.Gallery(label="Retrieved Pages")
165
 
166
- retrieve_button.click(retriever.retrieve_gradio, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
167
 
168
  app.launch()
169
 
 
90
  dpi = 100
91
  doc = fitz.open("pdf", pdf_file_binary)
92
 
93
+ for page in progress.tqdm(doc):
94
+ # with self.lock: # because we want a single 16 GB GPU to process only one image at a time
95
+ pix = page.get_pixmap(dpi=dpi)
96
+ image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
97
+ image_md5 = get_image_md5(image)
98
+ with torch.no_grad():
99
+ reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
100
+ self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
101
+ self.images[knowledge_base_name][image_md5] = image
 
102
 
103
  return knowledge_base_name
104
 
 
136
 
137
  retriever = PDFVisualRetrieval(model=model, tokenizer=tokenizer)
138
 
139
+ @spaces.GPU
140
+ def add_pdf_gradio(pdf_file_binary):
141
+ return retriever.add_pdf_gradio(pdf_file_binary)
142
+
143
+ @spaces.GPU
144
+ def retrieve_gradio(knowledge_base, query, topk):
145
+ return retriever.retrieve_gradio(knowledge_base, query, topk)
146
+
147
  # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='what is the number of VQ of this kind of codec method?', topk=1)
148
  # # 2
149
  # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the training loss curve of this paper?', topk=1)
 
159
  file_result = gr.Text(label="Knowledge Base ID (remember this!)")
160
  process_button = gr.Button("Process PDF")
161
 
162
+ process_button.click(add_pdf_gradio, inputs=[file_input], outputs=file_result)
163
 
164
  with gr.Row():
165
  kb_id_input = gr.Text(label="Your Knowledge Base ID")
 
170
  with gr.Row():
171
  images_output = gr.Gallery(label="Retrieved Pages")
172
 
173
+ retrieve_button.click(retrieve_gradio, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
174
 
175
  app.launch()
176