vjeronymo2's picture
Adding model and checkpoint
828992f
raw
history blame contribute delete
984 Bytes
import os
import torch
from colbert.utils.runs import Run
from colbert.utils.utils import print_message, save_checkpoint
from colbert.parameters import SAVED_CHECKPOINTS
def print_progress(scores):
positive_avg, negative_avg = round(scores[:, 0].mean().item(), 2), round(scores[:, 1].mean().item(), 2)
print("#>>> ", positive_avg, negative_avg, '\t\t|\t\t', positive_avg - negative_avg)
def manage_checkpoints(args, colbert, optimizer, batch_idx):
arguments = args.input_arguments.__dict__
path = os.path.join(Run.path, 'checkpoints')
if not os.path.exists(path):
os.mkdir(path)
if batch_idx % 2000 == 0:
name = os.path.join(path, "colbert.dnn")
save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments)
if batch_idx in SAVED_CHECKPOINTS:
name = os.path.join(path, "colbert-{}.dnn".format(batch_idx))
save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments)