Update README.md
Browse files
README.md
CHANGED
@@ -2684,7 +2684,7 @@ print(embeddings)
|
|
2684 |
|
2685 |
### Transformers
|
2686 |
|
2687 |
-
```
|
2688 |
import torch
|
2689 |
import torch.nn.functional as F
|
2690 |
from transformers import AutoTokenizer, AutoModel
|
@@ -2702,10 +2702,13 @@ model.eval()
|
|
2702 |
|
2703 |
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
2704 |
|
|
|
|
|
2705 |
with torch.no_grad():
|
2706 |
model_output = model(**encoded_input)
|
2707 |
|
2708 |
embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
|
|
2709 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
2710 |
print(embeddings)
|
2711 |
```
|
|
|
2684 |
|
2685 |
### Transformers
|
2686 |
|
2687 |
+
```diff
|
2688 |
import torch
|
2689 |
import torch.nn.functional as F
|
2690 |
from transformers import AutoTokenizer, AutoModel
|
|
|
2702 |
|
2703 |
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
2704 |
|
2705 |
+
+ matryoshka_dim = 512
|
2706 |
+
|
2707 |
with torch.no_grad():
|
2708 |
model_output = model(**encoded_input)
|
2709 |
|
2710 |
embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
2711 |
+
+ embeddings = embeddings[:, :matryoshka_dim]
|
2712 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
2713 |
print(embeddings)
|
2714 |
```
|