alphakotenok committed
Commit 4226750 · verified · 1 Parent(s): e00a879

Update modeling_zeranker.py

Files changed (1)
  1. modeling_zeranker.py +39 -13
modeling_zeranker.py CHANGED
@@ -1,7 +1,7 @@
 from sentence_transformers import CrossEncoder as _CE
 
 import math
-from typing import cast
+from typing import cast, Any
 import types
 
 import torch
@@ -21,8 +21,11 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 # pyright: reportUnknownMemberType=false
 # pyright: reportUnknownVariableType=false
 
-MODEL_PATH = "zeroentropy/ze-rerank-large-v0.3.0"
+MODEL_PATH = "zeroentropy/zerank-1"
 PER_DEVICE_BATCH_SIZE_TOKENS = 15_000
+global_device = (
+    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+)
 
 
 def format_pointwise_datapoints(
@@ -67,7 +70,7 @@ def load_model(
     | Qwen3ForCausalLM,
 ]:
     if device is None:
-        device = torch.device("cpu")
+        device = global_device
 
     config = AutoConfig.from_pretrained(MODEL_PATH)
     assert isinstance(config, PretrainedConfig)
@@ -80,7 +83,6 @@ def load_model(
     )
     if config.model_type == "llama":
         model.config.attn_implementation = "flash_attention_2"
-    print(f"Model Type: {config.model_type}")
     assert isinstance(
         model,
         LlamaForCausalLM
@@ -104,13 +106,30 @@ def load_model(
     return tokenizer, model
 
 
-def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
+def predict(
+    self,
+    query_documents: list[tuple[str, str]] | None = None,
+    *,
+    sentences: Any = None,
+    batch_size: Any = None,
+    show_progress_bar: Any = None,
+    activation_fn: Any = None,
+    apply_softmax: Any = None,
+    convert_to_numpy: Any = None,
+    convert_to_tensor: Any = None,
+) -> list[float]:
+    if query_documents is None:
+        if sentences is None:
+            raise ValueError("query_documents or sentences must be provided")
+        query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
+
     if not hasattr(self, "inner_model"):
-        self.inner_tokenizer, self.inner_model = load_model(torch.device("cuda"))
+        self.inner_tokenizer, self.inner_model = load_model(global_device)
         self.inner_model.gradient_checkpointing_enable()
         self.inner_model.eval()
-        self.inner_yes_token_id = self.inner_tokenizer.encode("Yes", add_special_tokens=False)[0]
-    print("patched")
+        self.inner_yes_token_id = self.inner_tokenizer.encode(
+            "Yes", add_special_tokens=False
+        )[0]
 
     model = self.inner_model
     tokenizer = self.inner_tokenizer
@@ -120,11 +139,11 @@ def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
     ]
     # Sort
     permutation = list(range(len(query_documents)))
-    permutation.sort(key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1]))
+    permutation.sort(
+        key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1])
+    )
     query_documents = [query_documents[i] for i in permutation]
 
-    device = torch.device("cuda")
-
     # Extract document batches from this line of datapoints
     max_length = 0
     batches: list[list[tuple[str, str]]] = []
@@ -148,7 +167,7 @@ def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
             batch,
         )
 
-        batch_inputs = batch_inputs.to(device)
+        batch_inputs = batch_inputs.to(global_device)
 
         try:
            outputs = model(**batch_inputs, use_cache=False)
@@ -164,7 +183,7 @@ def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
         last_positions = attention_mask.sum(dim=1) - 1
 
         batch_size = logits.shape[0]
-        batch_indices = torch.arange(batch_size, device=device)
+        batch_indices = torch.arange(batch_size, device=global_device)
         last_logits = logits[batch_indices, last_positions]
 
         yes_logits = last_logits[:, self.inner_yes_token_id]
@@ -181,8 +200,15 @@ def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
     return scores
 
 
+def to_device(self: _CE, new_device: torch.device) -> None:
+    global global_device
+    global_device = new_device
+
+
 _CE.predict = predict
 
 from transformers import Qwen3Config
 
 ZEConfig = Qwen3Config
+
+_CE.to = to_device
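
Usage sketch (not part of the commit): assuming this module ships with the zeroentropy/zerank-1 repository and is imported via trust_remote_code so its patches land on CrossEncoder, the changed entry points could be exercised roughly as follows. The loading call is an assumption; what the diff itself guarantees is that predict accepts either a positional query_documents list of (query, document) tuples or the sentence_transformers-style sentences keyword, and that the added .to only rebinds the module-level global_device.

import torch
from sentence_transformers import CrossEncoder

# Hypothetical loading call: assumes the repo wires this module in through
# trust_remote_code so the CrossEncoder monkey-patches are applied on import.
reranker = CrossEncoder("zeroentropy/zerank-1", trust_remote_code=True)

# The patched .to only swaps global_device; the inner model is lazily loaded
# onto that device on the first predict() call.
reranker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

query = "What is the capital of France?"
documents = [
    "Paris is the capital and most populous city of France.",
    "Berlin is the capital of Germany.",
]

# Positional form: a list of (query, document) tuples ...
scores = reranker.predict([(query, doc) for doc in documents])

# ... or the keyword form added by this commit for sentence_transformers callers.
scores_kw = reranker.predict(sentences=[(query, doc) for doc in documents])

print(scores, scores_kw)

Routing everything through global_device keeps the lazy load_model call, the per-batch .to transfers, and the torch.arange indexing on one device, even though sentence_transformers itself never touches the inner model.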