import time
import torch
import random
import torch.nn as nn
import numpy as np

from transformers import AdamW, get_linear_schedule_with_warmup

from colbert.infra import ColBERTConfig
from colbert.training.rerank_batcher import RerankBatcher
from colbert.training.lazy_batcher import LazyBatcher
from colbert.training.utils import print_progress, manage_checkpoints

from colbert.utils.amp import MixedPrecisionManager
from colbert.utils.utils import print_message

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.modeling.reranker.electra import ElectraReranker

def train(config: ColBERTConfig, triples, queries=None, collection=None):
    config.checkpoint = config.checkpoint or 'bert-base-uncased'

    if config.rank < 1:
        config.help()

    # Fix all seeds for reproducibility.
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    torch.cuda.manual_seed_all(12345)

    # The global batch size must divide evenly across the distributed processes.
    assert config.bsize % config.nranks == 0, (config.bsize, config.nranks)
    config.bsize = config.bsize // config.nranks

    print("Using config.bsize =", config.bsize, "(per process) and config.accumsteps =", config.accumsteps)

    if collection is not None:
        if config.reranker:
            reader = RerankBatcher(config, triples, queries, collection, (0 if config.rank == -1 else config.rank), config.nranks)
        else:
            reader = LazyBatcher(config, triples, queries, collection, (0 if config.rank == -1 else config.rank), config.nranks)
    else:
        raise NotImplementedError()
    # Build either the full ColBERT late-interaction model or a cross-encoder reranker.
    if not config.reranker:
        colbert = ColBERT(name=config.checkpoint, colbert_config=config)
    else:
        colbert = ElectraReranker.from_pretrained(config.checkpoint)

    colbert = colbert.to(DEVICE)
    colbert.train()

    colbert = torch.nn.parallel.DistributedDataParallel(colbert, device_ids=[config.rank],
                                                        output_device=config.rank,
                                                        find_unused_parameters=True)

    optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=config.lr, eps=1e-8)
    optimizer.zero_grad()

    scheduler = None
    if config.warmup is not None:
        print(f"#> LR will use {config.warmup} warmup steps and linear decay over {config.maxsteps} steps.")
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup,
                                                    num_training_steps=config.maxsteps)

    # Optionally freeze the underlying BERT encoder for the first `warmup_bert` steps.
    warmup_bert = config.warmup_bert
    if warmup_bert is not None:
        set_bert_grad(colbert, False)
    amp = MixedPrecisionManager(config.amp)
    labels = torch.zeros(config.bsize, dtype=torch.long, device=DEVICE)

    start_time = time.time()
    train_loss = None
    train_loss_mu = 0.999  # decay factor for the exponential moving average of the loss

    start_batch_idx = 0

    # if config.resume:
    #     assert config.checkpoint is not None
    #     start_batch_idx = checkpoint['batch']
    #     reader.skip_to_batch(start_batch_idx, checkpoint['arguments']['bsize'])
    for batch_idx, BatchSteps in zip(range(start_batch_idx, config.maxsteps), reader):
        # Unfreeze the BERT encoder once its warmup period is over.
        if (warmup_bert is not None) and warmup_bert <= batch_idx:
            set_bert_grad(colbert, True)
            warmup_bert = None

        this_batch_loss = 0.0

        for batch in BatchSteps:
            with amp.context():
                try:
                    # Late-interaction batches: (queries, passages, target_scores).
                    queries, passages, target_scores = batch
                    encoding = [queries, passages]
                except:
                    # Reranker batches: (encoding, target_scores).
                    encoding, target_scores = batch
                    encoding = [encoding.to(DEVICE)]

                scores = colbert(*encoding)

                if config.use_ib_negatives:
                    scores, ib_loss = scores

                scores = scores.view(-1, config.nway)

                if len(target_scores) and not config.ignore_scores:
                    # Distillation: KL divergence against the (scaled) teacher scores.
                    target_scores = torch.tensor(target_scores).view(-1, config.nway).to(DEVICE)
                    target_scores = target_scores * config.distillation_alpha
                    target_scores = torch.nn.functional.log_softmax(target_scores, dim=-1)

                    log_scores = torch.nn.functional.log_softmax(scores, dim=-1)
                    loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=True)(log_scores, target_scores)
                else:
                    # Otherwise, treat the first passage in each group as the positive.
                    loss = nn.CrossEntropyLoss()(scores, labels[:scores.size(0)])

                if config.use_ib_negatives:
                    if config.rank < 1:
                        print('\t\t\t\t', loss.item(), ib_loss.item())

                    loss += ib_loss

                loss = loss / config.accumsteps

            if config.rank < 1:
                print_progress(scores)

            amp.backward(loss)

            this_batch_loss += loss.item()

        # Track an exponential moving average of the per-batch loss.
        train_loss = this_batch_loss if train_loss is None else train_loss
        train_loss = train_loss_mu * train_loss + (1 - train_loss_mu) * this_batch_loss
        amp.step(colbert, optimizer, scheduler)

        if config.rank < 1:
            print_message(batch_idx, train_loss)

            # Periodic checkpointing (saving is handled inside manage_checkpoints).
            manage_checkpoints(config, colbert, optimizer, batch_idx+1, savepath=None)

    if config.rank < 1:
        print_message("#> Done with all triples!")

        ckpt_path = manage_checkpoints(config, colbert, optimizer, batch_idx+1, savepath=None, consumed_all_triples=True)

        return ckpt_path  # TODO: This should validate and return the best checkpoint, not just the last one.


def set_bert_grad(colbert, value):
    """Toggle requires_grad on the underlying BERT parameters (used for BERT warmup)."""
    try:
        for p in colbert.bert.parameters():
            assert p.requires_grad is (not value)
            p.requires_grad = value
    except AttributeError:
        # The model is wrapped in DistributedDataParallel; recurse into .module.
        set_bert_grad(colbert.module, value)
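

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file; the paths and config
# values below are illustrative assumptions). In the upstream ColBERT
# codebase, train() is normally driven through the Trainer / Run
# infrastructure rather than called directly, so this is only a rough sketch
# of a single-process invocation:
#
#     config = ColBERTConfig(bsize=32, lr=3e-6, warmup=1000, maxsteps=10000)
#     ckpt = train(config, triples='triples.jsonl',        # hypothetical paths
#                  queries='queries.tsv', collection='collection.tsv')
# ---------------------------------------------------------------------------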