import os
import random
import time
import torch
import torch.nn as nn
import numpy as np
from transformers import AdamW
from colbert.utils.runs import Run
from colbert.utils.amp import MixedPrecisionManager
from colbert.training.lazy_batcher import LazyBatcher
from colbert.training.eager_batcher import EagerBatcher
from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message
from colbert.training.utils import print_progress, manage_checkpoints


def train(args):
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    if args.distributed:
        torch.cuda.manual_seed_all(12345)
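
    # Under distributed training, args.bsize becomes the per-process batch size.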
    if args.distributed:
        assert args.bsize % args.nranks == 0, (args.bsize, args.nranks)
        assert args.accumsteps == 1
        args.bsize = args.bsize // args.nranks

        print("Using args.bsize =", args.bsize, "(per process) and args.accumsteps =", args.accumsteps)
    if args.lazy:
        reader = LazyBatcher(args, (0 if args.rank == -1 else args.rank), args.nranks)
    else:
        reader = EagerBatcher(args, (0 if args.rank == -1 else args.rank), args.nranks)
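
    # Hold non-master ranks at a barrier so rank 0 can download and cache the pretrained weights first.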
    if args.rank not in [-1, 0]:
        torch.distributed.barrier()

    colbert = ColBERT.from_pretrained('bert-base-multilingual-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)
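
    # Optionally warm-start the model weights from a checkpoint; the optimizer state is not restored.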
    if args.checkpoint is not None:
        assert args.resume_optimizer is False, "TODO: This would mean reload optimizer too."
        print_message(f"#> Starting from checkpoint {args.checkpoint} -- but NOT the optimizer!")

        checkpoint = torch.load(args.checkpoint, map_location='cpu')

        try:
            colbert.load_state_dict(checkpoint['model_state_dict'])
        except:
            print_message("[WARNING] Loading checkpoint with strict=False")
            colbert.load_state_dict(checkpoint['model_state_dict'], strict=False)
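
    # Rank 0 reaches the barrier last, releasing the ranks waiting above.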
    if args.rank == 0:
        torch.distributed.barrier()

    colbert = colbert.to(DEVICE)
    colbert.train()

    if args.distributed:
        colbert = torch.nn.parallel.DistributedDataParallel(colbert, device_ids=[args.rank],
                                                            output_device=args.rank,
                                                            find_unused_parameters=True)

    optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=args.lr, eps=1e-8)
    optimizer.zero_grad()

    amp = MixedPrecisionManager(args.amp)
    criterion = nn.CrossEntropyLoss()
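    # The positive passage is scored in column 0 of each (positive, negative) pair, so every target label is 0.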
    labels = torch.zeros(args.bsize, dtype=torch.long, device=DEVICE)

    start_time = time.time()
    train_loss = 0.0

    start_batch_idx = 0
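
    # When resuming, fast-forward the reader to the batch index stored in the checkpoint.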
    if args.resume:
        assert args.checkpoint is not None
        start_batch_idx = checkpoint['batch']

        reader.skip_to_batch(start_batch_idx, checkpoint['arguments']['bsize'])
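
    # Main training loop: each batch yields accumsteps sub-batches for gradient accumulation.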
    for batch_idx, BatchSteps in zip(range(start_batch_idx, args.maxsteps), reader):
        this_batch_loss = 0.0

        for queries, passages in BatchSteps:
            with amp.context():
                # Scores come back as positives followed by negatives and are reshaped to
                # (sub_batch, 2) so that column 0 holds each example's positive score.
                scores = colbert(queries, passages).view(2, -1).permute(1, 0)
                loss = criterion(scores, labels[:scores.size(0)])
                loss = loss / args.accumsteps

            if args.rank < 1:
                print_progress(scores)

            amp.backward(loss)

            train_loss += loss.item()
            this_batch_loss += loss.item()

        amp.step(colbert, optimizer)
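
        # Only the master process logs metrics and manages checkpoints.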
        if args.rank < 1:
            avg_loss = train_loss / (batch_idx+1)

            num_examples_seen = (batch_idx - start_batch_idx) * args.bsize * args.nranks
            elapsed = float(time.time() - start_time)

            log_to_mlflow = (batch_idx % 20 == 0)
            Run.log_metric('train/avg_loss', avg_loss, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/batch_loss', this_batch_loss, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/examples', num_examples_seen, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/throughput', num_examples_seen / elapsed, step=batch_idx, log_to_mlflow=log_to_mlflow)

            print_message(batch_idx, avg_loss)
            manage_checkpoints(args, colbert, optimizer, batch_idx+1)