manu committed on
Commit
c0b2913
·
verified ·
1 Parent(s): 7c5bf94

Update processing_colqwen2.py

Browse files
Files changed (1) hide show
  1. processing_colqwen2.py +39 -3
processing_colqwen2.py CHANGED
@@ -6,10 +6,8 @@ from PIL import Image
6
  from transformers import BatchFeature
7
  from transformers.models.qwen2_vl import Qwen2VLProcessor
8
 
9
- from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
10
 
11
-
12
- class ColQwen2Processor(BaseVisualRetrieverProcessor, Qwen2VLProcessor):
13
  """
14
  Processor for ColQwen2.
15
  """
@@ -148,3 +146,41 @@ class ColQwen2Processor(BaseVisualRetrieverProcessor, Qwen2VLProcessor):
148
  Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
149
  """
150
  return self.score_multi_vector(qs, ps, device=device, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from transformers import BatchFeature
7
  from transformers.models.qwen2_vl import Qwen2VLProcessor
8
 
 
9
 
10
+ class ColQwen2Processor(Qwen2VLProcessor):
 
11
  """
12
  Processor for ColQwen2.
13
  """
 
146
  Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
147
  """
148
  return self.score_multi_vector(qs, ps, device=device, **kwargs)
149
+
150
+
151
# NOTE(review): a caller in this file invokes this as
# `self.score_multi_vector(qs, ps, device=device, **kwargs)`, but the signature
# takes no `self`. If this definition lives inside the class body it needs a
# `@staticmethod` decorator — confirm the intended placement.
def score_multi_vector(
    qs: List[torch.Tensor],
    ps: List[torch.Tensor],
    batch_size: int = 128,
    device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
    """
    Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.

    For every (query, passage) pair, each query vector is matched with its
    highest dot-product passage vector, and those per-query-vector maxima are
    summed. Queries and passages are processed in chunks of `batch_size` to
    bound the size of the intermediate similarity tensor.

    Args:
        qs: Query embeddings, one 2-D tensor `(n_query_vectors, dim)` per query.
            Lengths may differ between queries; they are zero-padded per chunk.
        ps: Passage embeddings, one 2-D tensor `(n_passage_vectors, dim)` per
            passage. Lengths may differ; they are zero-padded per chunk.
        batch_size: Number of queries (and of passages) scored per chunk.
            Must be a positive integer.
        device: Device on which the similarity computation runs. When falsy,
            it is resolved with `get_torch_device("auto")`.

    Returns:
        A float32 CPU tensor of shape `(len(qs), len(ps))` where entry
        `[i, j]` is the score of query `i` against passage `j`.

    Raises:
        ValueError: If `qs` or `ps` is empty, or `batch_size` is not positive.
    """
    # `len()` is used instead of truthiness on purpose: `bool()` of a
    # multi-element tensor raises, so this keeps tensor-like inputs working.
    if len(qs) == 0:
        raise ValueError("No queries provided")
    if len(ps) == 0:
        raise ValueError("No passages provided")
    # A zero step makes `range` fail with an unrelated-looking message and a
    # negative step silently skips the loop; reject both up front.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # NOTE(review): `get_torch_device` is not defined in the visible part of
    # this file — confirm it is imported/defined at module level.
    device = device or get_torch_device("auto")

    scores_list: List[torch.Tensor] = []

    for i in range(0, len(qs), batch_size):
        scores_batch = []
        qs_batch = torch.nn.utils.rnn.pad_sequence(
            qs[i : i + batch_size], batch_first=True, padding_value=0
        ).to(device)
        for j in range(0, len(ps), batch_size):
            ps_batch = torch.nn.utils.rnn.pad_sequence(
                ps[j : j + batch_size], batch_first=True, padding_value=0
            ).to(device)
            # b: queries, c: passages, n: query vectors, s: passage vectors.
            # NOTE(review): zero-padded passage positions score exactly 0 and
            # therefore win the max whenever every real similarity is
            # negative — confirm this is acceptable for the embeddings used.
            similarities = torch.einsum("bnd,csd->bcns", qs_batch, ps_batch)
            scores_batch.append(similarities.max(dim=3).values.sum(dim=2))
        # Move each finished row-block to CPU so device memory only ever holds
        # one chunk of scores at a time.
        scores_list.append(torch.cat(scores_batch, dim=1).cpu())

    scores = torch.cat(scores_list, dim=0)
    # Internal sanity check on the chunking logic, not input validation.
    assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

    return scores.to(torch.float32)