|
import os
|
|
import torch
|
|
|
|
from colbert.utils.runs import Run
|
|
from colbert.utils.utils import print_message, save_checkpoint
|
|
from colbert.parameters import SAVED_CHECKPOINTS
|
|
|
|
|
|
def print_progress(scores):
    """Print the mean positive score, mean negative score, and their gap.

    Args:
        scores: A 2-D tensor indexed as ``scores[:, 0]`` / ``scores[:, 1]``.
            Column 0 is averaged as the positive score and column 1 as the
            negative score.

    The two averages are rounded to 2 decimals. The gap is taken between the
    rounded values and then rounded again. That second rounding stops float
    noise such as ``0.29999999999999993`` from leaking into the log line, and
    keeps the gap consistent with the two numbers printed beside it.
    """
    positive_avg = round(scores[:, 0].mean().item(), 2)
    negative_avg = round(scores[:, 1].mean().item(), 2)

    # Re-round: the difference of two 2-decimal floats is generally not
    # exactly representable and would otherwise print with spurious digits.
    gap = round(positive_avg - negative_avg, 2)

    print("#>>> ", positive_avg, negative_avg, '\t\t|\t\t', gap)
|
|
|
|
|
|
def manage_checkpoints(args, colbert, optimizer, batch_idx, save_every=2000):
    """Save training checkpoints under ``<Run.path>/checkpoints``.

    Two kinds of checkpoint are written, and both may be written on the same
    call:

    * a rolling checkpoint ``colbert.dnn``. It is overwritten every
      ``save_every`` batches, i.e. whenever ``batch_idx % save_every == 0``;
    * a permanent checkpoint ``colbert-<batch_idx>.dnn``. It is written
      whenever ``batch_idx`` is listed in ``SAVED_CHECKPOINTS``.

    Args:
        args: Object whose ``input_arguments`` attribute's ``__dict__`` is
            stored alongside each checkpoint.
        colbert: Model passed through to ``save_checkpoint``.
        optimizer: Optimizer passed through to ``save_checkpoint``.
        batch_idx: Current batch index. Used to decide which saves happen and
            recorded in the checkpoint.
        save_every: Interval, in batches, for the rolling checkpoint. Defaults
            to 2000 (the previously hard-coded value). ``None`` or ``0``
            disables rolling saves.
    """
    arguments = args.input_arguments.__dict__

    path = os.path.join(Run.path, 'checkpoints')
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
    # creates any missing parent directories. A plain mkdir would raise in
    # both of those cases.
    os.makedirs(path, exist_ok=True)

    # NOTE(review): the literal 0 passed as save_checkpoint's second argument
    # is opaque here (presumably an epoch index) — confirm against
    # colbert.utils.utils.save_checkpoint.
    if save_every and batch_idx % save_every == 0:
        name = os.path.join(path, "colbert.dnn")
        save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments)

    if batch_idx in SAVED_CHECKPOINTS:
        name = os.path.join(path, "colbert-{}.dnn".format(batch_idx))
        save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments)
|
|
|