# splade/configs/fine-tune.yaml
# (provenance: fschlatt, commit 22f52a4 "update readme")
# lightning.pytorch==2.3.3
seed_everything: 0
# Lightning Trainer settings: bf16 mixed precision, train for a fixed number of steps.
trainer:
  precision: bf16-mixed
  max_steps: 50000
# Data module: batches are drawn from a pre-computed run file.
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    num_workers: 1
    train_batch_size: 64
    shuffle_train: true
    train_dataset:
      class_path: lightning_ir.RunDataset
      init_args:
        # Run file id; targets below come from this run's scores.
        run_path_or_id: msmarco-passage/train/rank-distillm/set-encoder
        depth: 100                      # use the top-100 entries per query
        sample_size: 8                  # sample 8 documents per query
        sampling_strategy: log_random
        targets: score                  # train against the run's scores
        normalize_targets: false
# SPLADE bi-encoder initialized from bert-base-uncased.
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.SpladeConfig
      init_args:
        query_pooling_strategy: max
        doc_pooling_strategy: max
        projection: mlm
        sparsification: relu_log
        embedding_dim: 30522            # vocabulary size of bert-base-uncased
        similarity_function: dot
        query_expansion: false
        attend_to_query_expanded_tokens: false
        query_mask_scoring_tokens: null
        query_aggregation_function: sum
        doc_expansion: false
        attend_to_doc_expanded_tokens: false
        doc_mask_scoring_tokens: null
        normalize: false
        add_marker_tokens: false
        query_length: 32
        doc_length: 256
    loss_functions:
      # A nested [loss, weight] pair scales this loss by 0.05.
      - - class_path: lightning_ir.SupervisedMarginMSE
        - 0.05
      - class_path: lightning_ir.KLDivergence
      - class_path: lightning_ir.FLOPSRegularization
        init_args:
          query_weight: 0.01
          doc_weight: 0.02
      - class_path: lightning_ir.InBatchCrossEntropy
        init_args:
          pos_sampling_technique: first
          neg_sampling_technique: first
          max_num_neg_samples: 8
# Optimizer: AdamW with a 2e-5 learning rate.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 2.0e-05
# LR schedule: linear warmup for 5000 steps, then constant.
lr_scheduler:
  class_path: lightning_ir.ConstantLRSchedulerWithLinearWarmup
  init_args:
    num_warmup_steps: 5000
    num_delay_steps: 0