jxm commited on
Commit
6ab272b
·
verified ·
1 Parent(s): 10f0819

Integrate with Sentence Transformers (#3)

Browse files

- Integrate with Sentence Transformers + README (bed6830fdf35eb145747402f8dea5803018d07ab)
- Bump up minimum version (aea6a04d99a95e89ffe0b4f5fc20dcd6e75fa85a)
- Replace local-only "." with "jxm/cde-small-v1" (9677008ed99e455a9026e1ea0062fe2cbaf73de5)

README.md CHANGED
@@ -1,6 +1,8 @@
1
  ---
2
  tags:
3
  - mteb
 
 
4
  model-index:
5
  - name: cde-small-v1
6
  results:
@@ -8660,8 +8662,184 @@ Our new model that naturally integrates "context tokens" into the embedding proc
8660
 
8661
  Our embedding model needs to be used in *two stages*. The first stage is to gather some dataset information by embedding a subset of the corpus using our "first-stage" model. The second stage is to actually embed queries and documents, conditioning on the corpus information from the first stage. Note that we can do the first stage part offline and only use the second-stage weights at inference time.
8662
 
 
8663
 
8664
- ## Loading the model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8665
 
8666
  Our model can be loaded using `transformers` out-of-the-box with "trust remote code" enabled. We use the default BERT uncased tokenizer:
8667
  ```python
@@ -8680,7 +8858,7 @@ query_prefix = "search_query: "
8680
  document_prefix = "search_document: "
8681
  ```
8682
 
8683
- ## First stage
8684
 
8685
  ```python
8686
  minicorpus_size = model.config.transductive_corpus_size
@@ -8692,7 +8870,7 @@ minicorpus_docs = tokenizer(
8692
  padding=True,
8693
  max_length=512,
8694
  return_tensors="pt"
8695
- )
8696
  import torch
8697
  from tqdm.autonotebook import tqdm
8698
 
@@ -8709,7 +8887,7 @@ for i in tqdm(range(0, len(minicorpus_docs["input_ids"]), batch_size)):
8709
  dataset_embeddings = torch.cat(dataset_embeddings)
8710
  ```
8711
 
8712
- ## Running the second stage
8713
 
8714
  Now that we have obtained "dataset embeddings" we can embed documents and queries like normal. Remember to use the document prefix for documents:
8715
  ```python
@@ -8719,7 +8897,7 @@ docs = tokenizer(
8719
  padding=True,
8720
  max_length=512,
8721
  return_tensors="pt"
8722
- ).to(device)
8723
 
8724
  with torch.no_grad():
8725
  doc_embeddings = model.second_stage_model(
@@ -8739,7 +8917,7 @@ queries = tokenizer(
8739
  padding=True,
8740
  max_length=512,
8741
  return_tensors="pt"
8742
- ).to(device)
8743
 
8744
  with torch.no_grad():
8745
  query_embeddings = model.second_stage_model(
@@ -8752,6 +8930,8 @@ query_embeddings /= query_embeddings.norm(p=2, dim=1, keepdim=True)
8752
 
8753
  these embeddings can be compared using dot product, since they're normalized.
8754
 
 
 
8755
  ### What if I don't know what my corpus will be ahead of time?
8756
 
8757
  If you can't obtain corpus information ahead of time, you still have to pass *something* as the dataset embeddings; our model will work fine in this case, but not quite as well; without corpus information, our model performance drops from 65.0 to 63.8 on MTEB. We provide [some random strings](https://huggingface.co/jxm/cde-small-v1/resolve/main/random_strings.txt) that worked well for us that can be used as a substitute for corpus sampling.
 
1
  ---
2
  tags:
3
  - mteb
4
+ - transformers
5
+ - sentence-transformers
6
  model-index:
7
  - name: cde-small-v1
8
  results:
 
8662
 
8663
  Our embedding model needs to be used in *two stages*. The first stage is to gather some dataset information by embedding a subset of the corpus using our "first-stage" model. The second stage is to actually embed queries and documents, conditioning on the corpus information from the first stage. Note that we can do the first stage part offline and only use the second-stage weights at inference time.
8664
 
8665
+ ## With Sentence Transformers
8666
 
8667
+ <details open="">
8668
+ <summary>Click to learn how to use cde-small-v1 with Sentence Transformers</summary>
8669
+
8670
+ ### Loading the model
8671
+
8672
+ Our model can be loaded using `sentence-transformers` out-of-the-box with "trust remote code" enabled:
8673
+ ```python
8674
+ from sentence_transformers import SentenceTransformer
8675
+
8676
+ model = SentenceTransformer("jxm/cde-small-v1", trust_remote_code=True)
8677
+ ```
8678
+
8679
+ #### Note on prefixes
8680
+
8681
+ *Nota bene*: Like all state-of-the-art embedding models, our model was trained with task-specific prefixes. To do retrieval, you can use `prompt_name="query"` and `prompt_name="document"` in the `encode` method of the model when embedding queries and documents, respectively.
8682
+
8683
+ ### First stage
8684
+
8685
+ ```python
8686
+ minicorpus_size = model[0].config.transductive_corpus_size
8687
+ minicorpus_docs = [ ... ] # Put some strings here that are representative of your corpus, for example by calling random.sample(corpus, k=minicorpus_size)
8688
+ assert len(minicorpus_docs) == minicorpus_size # You must use exactly this many documents in the minicorpus. You can oversample if your corpus is smaller.
8689
+
8690
+ dataset_embeddings = model.encode(
8691
+ minicorpus_docs,
8692
+ prompt_name="document",
8693
+ convert_to_tensor=True
8694
+ )
8695
+ ```
8696
+
8697
+ ### Running the second stage
8698
+
8699
+ Now that we have obtained "dataset embeddings" we can embed documents and queries like normal. Remember to use the document prompt for documents:
8700
+
8701
+ ```python
8702
+ docs = [...]
8703
+ queries = [...]
8704
+
8705
+ doc_embeddings = model.encode(
8706
+ docs,
8707
+ prompt_name="document",
8708
+ dataset_embeddings=dataset_embeddings,
8709
+ convert_to_tensor=True,
8710
+ )
8711
+ query_embeddings = model.encode(
8712
+ queries,
8713
+ prompt_name="query",
8714
+ dataset_embeddings=dataset_embeddings,
8715
+ convert_to_tensor=True,
8716
+ )
8717
+ ```
8718
+
8719
+ these embeddings can be compared using cosine similarity via `model.similarity`:
8720
+ ```python
8721
+ similarities = model.similarity(query_embeddings, doc_embeddings)
8722
+ topk_values, topk_indices = similarities.topk(5)
8723
+ ```
8724
+
8725
+ <details>
8726
+ <summary>Click here for a full copy-paste ready example</summary>
8727
+
8728
+ ```python
8729
+ from sentence_transformers import SentenceTransformer
8730
+ from datasets import load_dataset
8731
+
8732
+ # 1. Load the Sentence Transformer model
8733
+ model = SentenceTransformer("jxm/cde-small-v1", trust_remote_code=True)
8734
+ context_docs_size = model[0].config.transductive_corpus_size # 512
8735
+
8736
+ # 2. Load the dataset: context dataset, docs, and queries
8737
+ dataset = load_dataset("sentence-transformers/natural-questions", split="train")
8738
+ dataset.shuffle(seed=42)
8739
+ # 10 queries, 512 context docs, 500 docs
8740
+ queries = dataset["query"][:10]
8741
+ docs = dataset["answer"][:2000]
8742
+ context_docs = dataset["answer"][-context_docs_size:] # Last 512 docs
8743
+
8744
+ # 3. First stage: embed the context docs
8745
+ dataset_embeddings = model.encode(
8746
+ context_docs,
8747
+ prompt_name="document",
8748
+ convert_to_tensor=True,
8749
+ )
8750
+
8751
+ # 4. Second stage: embed the docs and queries
8752
+ doc_embeddings = model.encode(
8753
+ docs,
8754
+ prompt_name="document",
8755
+ dataset_embeddings=dataset_embeddings,
8756
+ convert_to_tensor=True,
8757
+ )
8758
+ query_embeddings = model.encode(
8759
+ queries,
8760
+ prompt_name="query",
8761
+ dataset_embeddings=dataset_embeddings,
8762
+ convert_to_tensor=True,
8763
+ )
8764
+
8765
+ # 5. Compute the similarity between the queries and docs
8766
+ similarities = model.similarity(query_embeddings, doc_embeddings)
8767
+ topk_values, topk_indices = similarities.topk(5)
8768
+ print(topk_values)
8769
+ print(topk_indices)
8770
+
8771
+ """
8772
+ tensor([[0.5495, 0.5426, 0.5423, 0.5292, 0.5286],
8773
+ [0.6357, 0.6334, 0.6177, 0.5862, 0.5794],
8774
+ [0.7648, 0.5452, 0.5000, 0.4959, 0.4881],
8775
+ [0.6802, 0.5225, 0.5178, 0.5160, 0.5075],
8776
+ [0.6947, 0.5843, 0.5619, 0.5344, 0.5298],
8777
+ [0.7742, 0.7742, 0.7742, 0.7231, 0.6224],
8778
+ [0.8853, 0.6667, 0.5829, 0.5795, 0.5769],
8779
+ [0.6911, 0.6127, 0.6003, 0.5986, 0.5936],
8780
+ [0.6796, 0.6053, 0.6000, 0.5911, 0.5884],
8781
+ [0.7624, 0.5589, 0.5428, 0.5278, 0.5275]], device='cuda:0')
8782
+ tensor([[ 0, 296, 234, 1651, 1184],
8783
+ [1542, 466, 438, 1207, 1911],
8784
+ [ 2, 1562, 632, 1852, 382],
8785
+ [ 3, 694, 932, 1765, 662],
8786
+ [ 4, 35, 747, 26, 432],
8787
+ [ 534, 175, 5, 1495, 575],
8788
+ [ 6, 1802, 1875, 747, 21],
8789
+ [ 7, 1913, 1936, 640, 6],
8790
+ [ 8, 747, 167, 1318, 1743],
8791
+ [ 9, 1583, 1145, 219, 357]], device='cuda:0')
8792
+ """
8793
+ # As you can see, almost every query_i has document_i as the most similar document.
8794
+
8795
+ # 6. Print the top-k results
8796
+ for query_idx, top_doc_idx in enumerate(topk_indices[:, 0]):
8797
+ print(f"Query {query_idx}: {queries[query_idx]}")
8798
+ print(f"Top Document: {docs[top_doc_idx]}")
8799
+ print()
8800
+ """
8801
+ Query 0: when did richmond last play in a preliminary final
8802
+ Top Document: Richmond Football Club Richmond began 2017 with 5 straight wins, a feat it had not achieved since 1995. A series of close losses hampered the Tigers throughout the middle of the season, including a 5-point loss to the Western Bulldogs, 2-point loss to Fremantle, and a 3-point loss to the Giants. Richmond ended the season strongly with convincing victories over Fremantle and St Kilda in the final two rounds, elevating the club to 3rd on the ladder. Richmond's first final of the season against the Cats at the MCG attracted a record qualifying final crowd of 95,028; the Tigers won by 51 points. Having advanced to the first preliminary finals for the first time since 2001, Richmond defeated Greater Western Sydney by 36 points in front of a crowd of 94,258 to progress to the Grand Final against Adelaide, their first Grand Final appearance since 1982. The attendance was 100,021, the largest crowd to a grand final since 1986. The Crows led at quarter time and led by as many as 13, but the Tigers took over the game as it progressed and scored seven straight goals at one point. They eventually would win by 48 points – 16.12 (108) to Adelaide's 8.12 (60) – to end their 37-year flag drought.[22] Dustin Martin also became the first player to win a Premiership medal, the Brownlow Medal and the Norm Smith Medal in the same season, while Damien Hardwick was named AFL Coaches Association Coach of the Year. Richmond's jump from 13th to premiers also marked the biggest jump from one AFL season to the next.
8803
+
8804
+ Query 1: who sang what in the world's come over you
8805
+ Top Document: Life's What You Make It (Talk Talk song) "Life's What You Make It" is a song by the English band Talk Talk. It was released as a single in 1986, the first from the band's album The Colour of Spring. The single was a hit in the UK, peaking at No. 16, and charted in numerous other countries, often reaching the Top 20.
8806
+
8807
+ Query 2: who produces the most wool in the world
8808
+ Top Document: Wool Global wool production is about 2 million tonnes per year, of which 60% goes into apparel. Wool comprises ca 3% of the global textile market, but its value is higher owing to dying and other modifications of the material.[1] Australia is a leading producer of wool which is mostly from Merino sheep but has been eclipsed by China in terms of total weight.[30] New Zealand (2016) is the third-largest producer of wool, and the largest producer of crossbred wool. Breeds such as Lincoln, Romney, Drysdale, and Elliotdale produce coarser fibers, and wool from these sheep is usually used for making carpets.
8809
+
8810
+ Query 3: where does alaska the last frontier take place
8811
+ Top Document: Alaska: The Last Frontier Alaska: The Last Frontier is an American reality cable television series on the Discovery Channel, currently in its 7th season of broadcast. The show documents the extended Kilcher family, descendants of Swiss immigrants and Alaskan pioneers, Yule and Ruth Kilcher, at their homestead 11 miles outside of Homer.[1] By living without plumbing or modern heating, the clan chooses to subsist by farming, hunting and preparing for the long winters.[2] The Kilcher family are relatives of the singer Jewel,[1][3] who has appeared on the show.[4]
8812
+
8813
+ Query 4: a day to remember all i want cameos
8814
+ Top Document: All I Want (A Day to Remember song) The music video for the song, which was filmed in October 2010,[4] was released on January 6, 2011.[5] It features cameos of numerous popular bands and musicians. The cameos are: Tom Denney (A Day to Remember's former guitarist), Pete Wentz, Winston McCall of Parkway Drive, The Devil Wears Prada, Bring Me the Horizon, Sam Carter of Architects, Tim Lambesis of As I Lay Dying, Silverstein, Andrew WK, August Burns Red, Seventh Star, Matt Heafy of Trivium, Vic Fuentes of Pierce the Veil, Mike Herrera of MxPx, and Set Your Goals.[5] Rock Sound called the video "quite excellent".[5]
8815
+
8816
+ Query 5: what does the red stripes mean on the american flag
8817
+ Top Document: Flag of the United States The flag of the United States of America, often referred to as the American flag, is the national flag of the United States. It consists of thirteen equal horizontal stripes of red (top and bottom) alternating with white, with a blue rectangle in the canton (referred to specifically as the "union") bearing fifty small, white, five-pointed stars arranged in nine offset horizontal rows, where rows of six stars (top and bottom) alternate with rows of five stars. The 50 stars on the flag represent the 50 states of the United States of America, and the 13 stripes represent the thirteen British colonies that declared independence from the Kingdom of Great Britain, and became the first states in the U.S.[1] Nicknames for the flag include The Stars and Stripes,[2] Old Glory,[3] and The Star-Spangled Banner.
8818
+
8819
+ Query 6: where did they film diary of a wimpy kid
8820
+ Top Document: Diary of a Wimpy Kid (film) Filming of Diary of a Wimpy Kid was in Vancouver and wrapped up on October 16, 2009.
8821
+
8822
+ Query 7: where was beasts of the southern wild filmed
8823
+ Top Document: Beasts of the Southern Wild The film's fictional setting, "Isle de Charles Doucet", known to its residents as the Bathtub, was inspired by several isolated and independent fishing communities threatened by erosion, hurricanes and rising sea levels in Louisiana's Terrebonne Parish, most notably the rapidly eroding Isle de Jean Charles. It was filmed in Terrebonne Parish town Montegut.[5]
8824
+
8825
+ Query 8: what part of the country are you likely to find the majority of the mollisols
8826
+ Top Document: Mollisol Mollisols occur in savannahs and mountain valleys (such as Central Asia, or the North American Great Plains). These environments have historically been strongly influenced by fire and abundant pedoturbation from organisms such as ants and earthworms. It was estimated that in 2003, only 14 to 26 percent of grassland ecosystems still remained in a relatively natural state (that is, they were not used for agriculture due to the fertility of the A horizon). Globally, they represent ~7% of ice-free land area. As the world's most agriculturally productive soil order, the Mollisols represent one of the more economically important soil orders.
8827
+
8828
+ Query 9: when did fosters home for imaginary friends start
8829
+ Top Document: Foster's Home for Imaginary Friends McCracken conceived the series after adopting two dogs from an animal shelter and applying the concept to imaginary friends. The show first premiered on Cartoon Network on August 13, 2004, as a 90-minute television film. On August 20, it began its normal run of twenty-to-thirty-minute episodes on Fridays, at 7 pm. The series finished its run on May 3, 2009, with a total of six seasons and seventy-nine episodes. McCracken left Cartoon Network shortly after the series ended. Reruns have aired on Boomerang from August 11, 2012 to November 3, 2013 and again from June 1, 2014 to April 3, 2017.
8830
+ """
8831
+ ```
8832
+
8833
+ </details>
8834
+
8835
+ </details>
8836
+
8837
+ ## With Transformers
8838
+
8839
+ <details>
8840
+ <summary>Click to learn how to use cde-small-v1 with Transformers</summary>
8841
+
8842
+ ### Loading the model
8843
 
8844
  Our model can be loaded using `transformers` out-of-the-box with "trust remote code" enabled. We use the default BERT uncased tokenizer:
8845
  ```python
 
8858
  document_prefix = "search_document: "
8859
  ```
8860
 
8861
+ ### First stage
8862
 
8863
  ```python
8864
  minicorpus_size = model.config.transductive_corpus_size
 
8870
  padding=True,
8871
  max_length=512,
8872
  return_tensors="pt"
8873
+ ).to(model.device)
8874
  import torch
8875
  from tqdm.autonotebook import tqdm
8876
 
 
8887
  dataset_embeddings = torch.cat(dataset_embeddings)
8888
  ```
8889
 
8890
+ ### Running the second stage
8891
 
8892
  Now that we have obtained "dataset embeddings" we can embed documents and queries like normal. Remember to use the document prefix for documents:
8893
  ```python
 
8897
  padding=True,
8898
  max_length=512,
8899
  return_tensors="pt"
8900
+ ).to(model.device)
8901
 
8902
  with torch.no_grad():
8903
  doc_embeddings = model.second_stage_model(
 
8917
  padding=True,
8918
  max_length=512,
8919
  return_tensors="pt"
8920
+ ).to(model.device)
8921
 
8922
  with torch.no_grad():
8923
  query_embeddings = model.second_stage_model(
 
8930
 
8931
  these embeddings can be compared using dot product, since they're normalized.
8932
 
8933
+ </details>
8934
+
8935
  ### What if I don't know what my corpus will be ahead of time?
8936
 
8937
  If you can't obtain corpus information ahead of time, you still have to pass *something* as the dataset embeddings; our model will work fine in this case, but not quite as well; without corpus information, our model performance drops from 65.0 to 63.8 on MTEB. We provide [some random strings](https://huggingface.co/jxm/cde-small-v1/resolve/main/random_strings.txt) that worked well for us that can be used as a substitute for corpus sampling.
config_sentence_transformers.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "3.1.0",
4
+ "transformers": "4.43.4",
5
+ "pytorch": "2.5.0.dev20240807+cu121"
6
+ },
7
+ "prompts": {
8
+ "query": "search_query: ",
9
+ "document": "search_document: "
10
+ },
11
+ "default_prompt_name": null,
12
+ "similarity_fn_name": "cosine"
13
+ }
modules.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers_impl.Transformer",
7
+ "kwargs": ["dataset_embeddings"]
8
+ }
9
+ ]
sentence_bert_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
sentence_transformers_impl.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ from typing import Any, Optional
7
+
8
+ import torch
9
+ from torch import nn
10
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class Transformer(nn.Module):
16
+ """Hugging Face AutoModel to generate token embeddings.
17
+ Loads the correct class, e.g. BERT / RoBERTa etc.
18
+
19
+ Args:
20
+ model_name_or_path: Hugging Face models name
21
+ (https://huggingface.co/models)
22
+ max_seq_length: Truncate any inputs longer than max_seq_length
23
+ model_args: Keyword arguments passed to the Hugging Face
24
+ Transformers model
25
+ tokenizer_args: Keyword arguments passed to the Hugging Face
26
+ Transformers tokenizer
27
+ config_args: Keyword arguments passed to the Hugging Face
28
+ Transformers config
29
+ cache_dir: Cache dir for Hugging Face Transformers to store/load
30
+ models
31
+ do_lower_case: If true, lowercases the input (independent if the
32
+ model is cased or not)
33
+ tokenizer_name_or_path: Name or path of the tokenizer. When
34
+ None, then model_name_or_path is used
35
+ backend: Backend used for model inference. Can be `torch`, `onnx`,
36
+ or `openvino`. Default is `torch`.
37
+ """
38
+
39
+ save_in_root: bool = True
40
+
41
+ def __init__(
42
+ self,
43
+ model_name_or_path: str,
44
+ model_args: dict[str, Any] | None = None,
45
+ tokenizer_args: dict[str, Any] | None = None,
46
+ config_args: dict[str, Any] | None = None,
47
+ cache_dir: str | None = None,
48
+ **kwargs,
49
+ ) -> None:
50
+ super().__init__()
51
+ if model_args is None:
52
+ model_args = {}
53
+ if tokenizer_args is None:
54
+ tokenizer_args = {}
55
+ if config_args is None:
56
+ config_args = {}
57
+
58
+ if not model_args.get("trust_remote_code", False):
59
+ raise ValueError(
60
+ "You need to set `trust_remote_code=True` to load this model."
61
+ )
62
+
63
+ self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
64
+ self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
65
+
66
+ self.tokenizer = AutoTokenizer.from_pretrained(
67
+ "bert-base-uncased",
68
+ cache_dir=cache_dir,
69
+ **tokenizer_args,
70
+ )
71
+
72
+ def __repr__(self) -> str:
73
+ return f"Transformer({self.get_config_dict()}) with Transformer model: {self.auto_model.__class__.__name__} "
74
+
75
+ def forward(self, features: dict[str, torch.Tensor], dataset_embeddings: Optional[torch.Tensor] = None, **kwargs) -> dict[str, torch.Tensor]:
76
+ """Returns token_embeddings, cls_token"""
77
+ # If we don't have embeddings, then run the 1st stage model.
78
+ # If we do, then run the 2nd stage model.
79
+ if dataset_embeddings is None:
80
+ sentence_embedding = self.auto_model.first_stage_model(
81
+ input_ids=features["input_ids"],
82
+ attention_mask=features["attention_mask"],
83
+ )
84
+ else:
85
+ sentence_embedding = self.auto_model.second_stage_model(
86
+ input_ids=features["input_ids"],
87
+ attention_mask=features["attention_mask"],
88
+ dataset_embeddings=dataset_embeddings,
89
+ )
90
+
91
+ features["sentence_embedding"] = sentence_embedding
92
+ return features
93
+
94
+ def get_word_embedding_dimension(self) -> int:
95
+ return self.auto_model.config.hidden_size
96
+
97
+ def tokenize(
98
+ self, texts: list[str] | list[dict] | list[tuple[str, str]], padding: str | bool = True
99
+ ) -> dict[str, torch.Tensor]:
100
+ """Tokenizes a text and maps tokens to token-ids"""
101
+ output = {}
102
+ if isinstance(texts[0], str):
103
+ to_tokenize = [texts]
104
+ elif isinstance(texts[0], dict):
105
+ to_tokenize = []
106
+ output["text_keys"] = []
107
+ for lookup in texts:
108
+ text_key, text = next(iter(lookup.items()))
109
+ to_tokenize.append(text)
110
+ output["text_keys"].append(text_key)
111
+ to_tokenize = [to_tokenize]
112
+ else:
113
+ batch1, batch2 = [], []
114
+ for text_tuple in texts:
115
+ batch1.append(text_tuple[0])
116
+ batch2.append(text_tuple[1])
117
+ to_tokenize = [batch1, batch2]
118
+
119
+ max_seq_length = self.config.max_seq_length
120
+ output.update(
121
+ self.tokenizer(
122
+ *to_tokenize,
123
+ padding=padding,
124
+ truncation="longest_first",
125
+ return_tensors="pt",
126
+ max_length=max_seq_length,
127
+ )
128
+ )
129
+ return output
130
+
131
+ def get_config_dict(self) -> dict[str, Any]:
132
+ return {}
133
+
134
+ def save(self, output_path: str, safe_serialization: bool = True) -> None:
135
+ self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
136
+ self.tokenizer.save_pretrained(output_path)
137
+
138
+ with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as fOut:
139
+ json.dump(self.get_config_dict(), fOut, indent=2)
140
+
141
+ @classmethod
142
+ def load(cls, input_path: str) -> Transformer:
143
+ sbert_config_path = os.path.join(input_path, "sentence_bert_config.json")
144
+ if not os.path.exists(sbert_config_path):
145
+ return cls(model_name_or_path=input_path)
146
+
147
+ with open(sbert_config_path) as fIn:
148
+ config = json.load(fIn)
149
+ # Don't allow configs to set trust_remote_code
150
+ if "model_args" in config and "trust_remote_code" in config["model_args"]:
151
+ config["model_args"].pop("trust_remote_code")
152
+ if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
153
+ config["tokenizer_args"].pop("trust_remote_code")
154
+ if "config_args" in config and "trust_remote_code" in config["config_args"]:
155
+ config["config_args"].pop("trust_remote_code")
156
+ return cls(model_name_or_path=input_path, **config)