# mColBERT / colbert/train.py
# Provenance: commit 828992f ("Adding model and checkpoint") by vjeronymo2; original file size 963 bytes.
import os
import random
import torch
import copy
import colbert.utils.distributed as distributed
from colbert.utils.parser import Arguments
from colbert.utils.runs import Run
from colbert.training.training import train
def main():
    """Parse command-line arguments, validate them, and launch ColBERT training.

    Training input is described by the parser as
    <query, positive passage, negative passage> triples. After validation,
    ``train(args)`` runs inside a ``Run.context``.

    Raises:
        ValueError: if ``accumsteps`` is not positive, if ``bsize`` is not
            divisible by ``accumsteps``, or if ``query_maxlen`` / ``doc_maxlen``
            exceed the supported maximum sequence length (512).
    """
    parser = Arguments(description='Training ColBERT with <query, positive passage, negative passage> triples.')

    parser.add_model_parameters()
    parser.add_model_training_parameters()
    parser.add_training_input()

    args = parser.parse()

    # Explicit checks instead of `assert`: asserts are stripped under `python -O`,
    # which would silently disable this CLI validation.
    _validate_training_args(args)

    # NOTE(review): presumably makes the trainer read passages lazily from the
    # collection when one is supplied — confirm in colbert.training.
    args.lazy = args.collection is not None

    # NOTE(review): consider_failed_if_interrupted=False appears to keep an
    # interrupted run from being recorded as failed — confirm in colbert.utils.runs.
    with Run.context(consider_failed_if_interrupted=False):
        train(args)


def _validate_training_args(args):
    """Raise ValueError if the parsed training arguments are inconsistent."""
    # Same limit the original script enforced; presumably the encoder's maximum
    # number of positions — TODO confirm against the model config.
    max_seq_len = 512

    if args.accumsteps <= 0:
        raise ValueError(
            f"accumsteps must be a positive integer, got {args.accumsteps}."
        )
    # Each accumulation step processes bsize // accumsteps examples, so the split must be exact.
    if args.bsize % args.accumsteps != 0:
        raise ValueError(
            "The batch size must be divisible by the number of gradient accumulation steps. "
            f"(bsize={args.bsize}, accumsteps={args.accumsteps})"
        )
    if args.query_maxlen > max_seq_len:
        raise ValueError(
            f"query_maxlen must be <= {max_seq_len}, got {args.query_maxlen}."
        )
    if args.doc_maxlen > max_seq_len:
        raise ValueError(
            f"doc_maxlen must be <= {max_seq_len}, got {args.doc_maxlen}."
        )
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()