manu committed on
Commit
654c2e1
·
verified ·
1 Parent(s): d5db6a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -38
app.py CHANGED
@@ -8,43 +8,9 @@ from torch.utils.data import DataLoader
8
  from tqdm import tqdm
9
  from transformers import AutoProcessor
10
 
11
- from custom_colbert.models.paligemma_colbert_architecture import ColPali
12
- from custom_colbert.trainer.retrieval_evaluator import CustomEvaluator
13
-
14
-
15
- def process_images(processor, images, max_length: int = 50):
16
- texts_doc = ["Describe the image."] * len(images)
17
- images = [image.convert("RGB") for image in images]
18
-
19
- batch_doc = processor(
20
- text=texts_doc,
21
- images=images,
22
- return_tensors="pt",
23
- padding="longest",
24
- max_length=max_length + processor.image_seq_length,
25
- )
26
- return batch_doc
27
-
28
-
29
- def process_queries(processor, queries, mock_image, max_length: int = 50):
30
- texts_query = []
31
- for query in queries:
32
- query = f"Question: {query}<unused0><unused0><unused0><unused0><unused0>"
33
- texts_query.append(query)
34
-
35
- batch_query = processor(
36
- images=[mock_image.convert("RGB")] * len(texts_query),
37
- # NOTE: the image is not used in batch_query but it is required for calling the processor
38
- text=texts_query,
39
- return_tensors="pt",
40
- padding="longest",
41
- max_length=max_length + processor.image_seq_length,
42
- )
43
- del batch_query["pixel_values"]
44
-
45
- batch_query["input_ids"] = batch_query["input_ids"][..., processor.image_seq_length :]
46
- batch_query["attention_mask"] = batch_query["attention_mask"][..., processor.image_seq_length :]
47
- return batch_query
48
 
49
 
50
  def search(query: str, ds, images):
@@ -71,7 +37,7 @@ def index(file, ds):
71
  # run inference - docs
72
  dataloader = DataLoader(
73
  images,
74
- batch_size=8,
75
  shuffle=False,
76
  collate_fn=lambda x: process_images(processor, x),
77
  )
 
8
  from tqdm import tqdm
9
  from transformers import AutoProcessor
10
 
11
+ from colpali_engine.models.paligemma_colbert_architecture import ColPali
12
+ from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
13
+ from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  def search(query: str, ds, images):
 
37
  # run inference - docs
38
  dataloader = DataLoader(
39
  images,
40
+ batch_size=4,
41
  shuffle=False,
42
  collate_fn=lambda x: process_images(processor, x),
43
  )