bokesyo commited on
Commit
cc886ab
1 Parent(s): 4d1a2ae
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -79,7 +79,6 @@ class PDFVisualRetrieval:
79
  self.images[knowledge_base_name][image_md5] = image
80
  return
81
 
82
- @spaces.GPU
83
  def add_pdf_gradio(self, pdf_file_binary, progress=gr.Progress()):
84
  knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
85
  if knowledge_base_name not in self.reps:
@@ -90,18 +89,20 @@ class PDFVisualRetrieval:
90
  self.images[knowledge_base_name] = {}
91
  dpi = 100
92
  doc = fitz.open("pdf", pdf_file_binary)
93
- for page in progress.tqdm(doc):
94
- with self.lock: # because we hope one 16G gpu only process one image at the same time
95
- pix = page.get_pixmap(dpi=dpi)
96
- image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
97
- image_md5 = get_image_md5(image)
98
- with torch.no_grad():
99
- reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
100
- self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
101
- self.images[knowledge_base_name][image_md5] = image
 
 
 
102
  return knowledge_base_name
103
 
104
- @spaces.GPU
105
  def retrieve_gradio(self, knowledge_base: str, query: str, topk: int):
106
  doc_reps = list(self.reps[knowledge_base].values())
107
  query_with_instruction = "Represent this query for retrieving relavant document: " + query
 
79
  self.images[knowledge_base_name][image_md5] = image
80
  return
81
 
 
82
  def add_pdf_gradio(self, pdf_file_binary, progress=gr.Progress()):
83
  knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
84
  if knowledge_base_name not in self.reps:
 
89
  self.images[knowledge_base_name] = {}
90
  dpi = 100
91
  doc = fitz.open("pdf", pdf_file_binary)
92
+
93
+ with spaces.GPU():
94
+ for page in progress.tqdm(doc):
95
+ with self.lock: # because we hope one 16G gpu only process one image at the same time
96
+ pix = page.get_pixmap(dpi=dpi)
97
+ image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
98
+ image_md5 = get_image_md5(image)
99
+ with torch.no_grad():
100
+ reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
101
+ self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
102
+ self.images[knowledge_base_name][image_md5] = image
103
+
104
  return knowledge_base_name
105
 
 
106
  def retrieve_gradio(self, knowledge_base: str, query: str, topk: int):
107
  doc_reps = list(self.reps[knowledge_base].values())
108
  query_with_instruction = "Represent this query for retrieving relavant document: " + query