jupyterjazz
commited on
Commit
•
c55e591
1
Parent(s):
b27fa55
refactor: truncation fn
Browse filesSigned-off-by: jupyterjazz <[email protected]>
- modeling_xlm_roberta.py +14 -9
modeling_xlm_roberta.py
CHANGED
@@ -579,15 +579,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
579 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
580 |
|
581 |
if truncate_dim:
|
582 |
-
|
583 |
-
logger.warning(
|
584 |
-
'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
|
585 |
-
)
|
586 |
-
elif truncate_dim in self.config.matryoshka_dimensions:
|
587 |
-
all_embeddings = [tensor[:truncate_dim] for tensor in all_embeddings]
|
588 |
-
else:
|
589 |
-
raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
|
590 |
-
f'Supported dimensions are {self.config.matryoshka_dimensions}.')
|
591 |
|
592 |
if convert_to_tensor:
|
593 |
all_embeddings = torch.stack(all_embeddings)
|
@@ -600,6 +592,19 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
600 |
self.train(is_training)
|
601 |
return all_embeddings
|
602 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
603 |
def mean_pooling(
|
604 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
605 |
):
|
|
|
579 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
580 |
|
581 |
if truncate_dim:
|
582 |
+
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
583 |
|
584 |
if convert_to_tensor:
|
585 |
all_embeddings = torch.stack(all_embeddings)
|
|
|
592 |
self.train(is_training)
|
593 |
return all_embeddings
|
594 |
|
595 |
+
|
596 |
+
def truncate_embeddings(self, embeddings, truncate_dim):
|
597 |
+
if not self.config.matryoshka_dimensions:
|
598 |
+
logger.warning(
|
599 |
+
'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
|
600 |
+
)
|
601 |
+
return embeddings
|
602 |
+
elif truncate_dim in self.config.matryoshka_dimensions:
|
603 |
+
return [tensor[:truncate_dim] for tensor in embeddings]
|
604 |
+
else:
|
605 |
+
raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
|
606 |
+
f'Supported dimensions are {self.config.matryoshka_dimensions}.')
|
607 |
+
|
608 |
def mean_pooling(
|
609 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
610 |
):
|