Update processing_colqwen2.py
Browse files- processing_colqwen2.py +39 -3
processing_colqwen2.py
CHANGED
@@ -6,10 +6,8 @@ from PIL import Image
|
|
6 |
from transformers import BatchFeature
|
7 |
from transformers.models.qwen2_vl import Qwen2VLProcessor
|
8 |
|
9 |
-
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
|
10 |
|
11 |
-
|
12 |
-
class ColQwen2Processor(BaseVisualRetrieverProcessor, Qwen2VLProcessor):
|
13 |
"""
|
14 |
Processor for ColQwen2.
|
15 |
"""
|
@@ -148,3 +146,41 @@ class ColQwen2Processor(BaseVisualRetrieverProcessor, Qwen2VLProcessor):
|
|
148 |
Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
|
149 |
"""
|
150 |
return self.score_multi_vector(qs, ps, device=device, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from transformers import BatchFeature
|
7 |
from transformers.models.qwen2_vl import Qwen2VLProcessor
|
8 |
|
|
|
9 |
|
10 |
+
class ColQwen2Processor(Qwen2VLProcessor):
|
|
|
11 |
"""
|
12 |
Processor for ColQwen2.
|
13 |
"""
|
|
|
146 |
Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
|
147 |
"""
|
148 |
return self.score_multi_vector(qs, ps, device=device, **kwargs)
|
149 |
+
|
150 |
+
|
151 |
+
def score_multi_vector(
|
152 |
+
qs: List[torch.Tensor],
|
153 |
+
ps: List[torch.Tensor],
|
154 |
+
batch_size: int = 128,
|
155 |
+
device: Optional[Union[str, torch.device]] = None,
|
156 |
+
) -> torch.Tensor:
|
157 |
+
"""
|
158 |
+
Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
|
159 |
+
"""
|
160 |
+
device = device or get_torch_device("auto")
|
161 |
+
|
162 |
+
if len(qs) == 0:
|
163 |
+
raise ValueError("No queries provided")
|
164 |
+
if len(ps) == 0:
|
165 |
+
raise ValueError("No passages provided")
|
166 |
+
|
167 |
+
scores_list: List[torch.Tensor] = []
|
168 |
+
|
169 |
+
for i in range(0, len(qs), batch_size):
|
170 |
+
scores_batch = []
|
171 |
+
qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
|
172 |
+
device
|
173 |
+
)
|
174 |
+
for j in range(0, len(ps), batch_size):
|
175 |
+
ps_batch = torch.nn.utils.rnn.pad_sequence(
|
176 |
+
ps[j : j + batch_size], batch_first=True, padding_value=0
|
177 |
+
).to(device)
|
178 |
+
scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
|
179 |
+
scores_batch = torch.cat(scores_batch, dim=1).cpu()
|
180 |
+
scores_list.append(scores_batch)
|
181 |
+
|
182 |
+
scores = torch.cat(scores_list, dim=0)
|
183 |
+
assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
|
184 |
+
|
185 |
+
scores = scores.to(torch.float32)
|
186 |
+
return scores
|