|
import logging |
|
import traceback |
|
|
|
from datasets import load_dataset |
|
|
|
from sentence_transformers.cross_encoder import CrossEncoder |
|
from sentence_transformers.cross_encoder.evaluation.CENanoBEIREvaluator import ( |
|
CENanoBEIREvaluator, |
|
) |
|
from sentence_transformers.cross_encoder.losses import ListNetLoss |
|
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer |
|
from sentence_transformers.cross_encoder.training_args import ( |
|
CrossEncoderTrainingArguments, |
|
) |
|
|
|
|
|
def main():
    """Fine-tune a MiniLM cross-encoder as a listwise reranker on MS MARCO v1.1.

    Pipeline: load the base model, preprocess MS MARCO into fixed-width
    listwise samples, evaluate on NanoBEIR before training, train with
    ListNetLoss, re-evaluate, save locally, and (best-effort) push to the
    Hugging Face Hub.
    """
    model_name = "microsoft/MiniLM-L12-H384-uncased"

    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )

    train_batch_size = 8
    num_epochs = 1
    max_docs = 10  # fixed number of documents per listwise sample
    pad_value = -1  # label for padding slots; ListNetLoss ignores these positions
    loss_name = "listnet"
    num_labels = 1  # single relevance score per (query, document) pair

    # 1. Define the model we want to fine-tune as a reranker.
    model = CrossEncoder(model_name, num_labels=num_labels)
    # Use logging (configured above) rather than bare print for consistency.
    logging.info("Model max length: %s", model.max_length)
    logging.info("Model num labels: %s", model.num_labels)

    # 2. Load the MS MARCO dataset and reshape it into listwise samples.
    logging.info("Read train dataset")
    dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")

    def listwise_mapper(batch, max_docs: int = 10, pad_value: int = -1):
        """Convert a batch of MS MARCO rows into fixed-width listwise samples.

        Args:
            batch: Batched rows with "query" and "passages" columns, where
                "passages" holds parallel lists "passage_text" and "is_selected".
            max_docs: Exact number of documents per output sample; shorter
                lists are padded, longer ones truncated.
            pad_value: Label assigned to padding slots.

        Returns:
            A dict with "query", "docs", and "labels" columns. Queries with no
            passages at all, or without any relevant passage, are dropped.
        """
        processed_queries = []
        processed_docs = []
        processed_labels = []

        for query, passages_info in zip(batch["query"], batch["passages"]):
            passages = passages_info["passage_text"]
            labels = passages_info["is_selected"]

            # Sort passages by relevance (descending) so positives survive
            # the truncation to the first `max_docs` entries.
            paired = sorted(zip(passages, labels), key=lambda x: x[1], reverse=True)
            sorted_passages, sorted_labels = zip(*paired) if paired else ([], [])

            # Skip rows with no passages (max() on an empty sequence would
            # raise ValueError) and rows without any relevant passage.
            if not sorted_labels or max(sorted_labels) < 1.0:
                continue

            truncated_passages = list(sorted_passages[:max_docs])
            truncated_labels = list(sorted_labels[:max_docs])

            # Pad documents/labels so every sample has exactly max_docs entries.
            pad_count = max_docs - len(truncated_passages)
            processed_docs.append(truncated_passages + [""] * pad_count)
            processed_labels.append(truncated_labels + [pad_value] * pad_count)
            processed_queries.append(query)

        return {
            "query": processed_queries,
            "docs": processed_docs,
            "labels": processed_labels,
        }

    dataset = dataset.map(
        lambda batch: listwise_mapper(batch=batch, max_docs=max_docs, pad_value=pad_value),
        batched=True,
        remove_columns=dataset.column_names,
        desc="Processing listwise samples",
    )

    dataset = dataset.train_test_split(test_size=10_000)
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]
    logging.info(train_dataset)

    # 3. Define the listwise loss over (query, docs, labels) samples.
    loss = ListNetLoss(model, pad_value=pad_value)

    # 4. Baseline evaluation on a NanoBEIR subset before any training.
    evaluator = CENanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=train_batch_size)
    evaluator(model)

    # 5. Define the training arguments.
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-msmarco-v1.1-{short_model_name}-{loss_name}"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # set True if bf16 is unsupported on your hardware
        bf16=True,
        # Track the best checkpoint by NanoBEIR mean nDCG@10.
        load_best_model_at_end=True,
        metric_for_best_model="eval_NanoBEIR_mean_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1600,
        save_strategy="steps",
        save_steps=1600,
        save_total_limit=2,
        logging_steps=200,
        logging_first_step=True,
        run_name=run_name,
        seed=12,
    )

    # 6. Create the trainer and train.
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Re-evaluate after training to measure the improvement.
    evaluator(model)

    # 8. Save the final model locally.
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) Push to the Hugging Face Hub; failure is non-fatal and the
    # manual recovery steps are logged instead.
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":

    main()
|
|