# V-BeachNet / train_video_seg.py

import os
import time
import argparse
import numpy as np
from tqdm import tqdm
from glob import glob
import torch
from torch.utils.data import DataLoader
from video_module.dataset import Water_Image_Train_DS
from video_module.model import AFB_URR, FeatureBank
import myutils

# CUDA launch blocking is disabled by default; set to '1' to force synchronous
# kernel launches when debugging CUDA errors.
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
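
# Training script for V-BeachNet: fits the AFB_URR video segmentation model on
# clips from Water_Image_Train_DS. Each clip's first frame/mask pair seeds a
# feature bank, and the remaining frames are segmented against that memory
# (see train_model below).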


def get_args():
    parser = argparse.ArgumentParser(description='Train V-BeachNet')
    parser.add_argument('--gpu', type=int, default=0, help='GPU card id.')
    parser.add_argument('--dataset', type=str, required=True, help='Dataset folder.')
    parser.add_argument('--seed', type=int, default=-1, help='Random seed (negative values use the current time).')
    parser.add_argument('--log', action='store_true', help='Save the training results.')
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate (default: 1e-5).')
    parser.add_argument('--lu', type=float, default=0.5, help='Regularization factor for the uncertainty term (default: 0.5).')
    parser.add_argument('--resume', type=str, help='Path to checkpoint (default: none).')
    parser.add_argument('--new', action='store_true', help='Train the model from scratch.')
    parser.add_argument('--scheduler-step', type=int, default=25, help='Scheduler step size (default: 25).')
    parser.add_argument('--total-epochs', type=int, default=100, help='Total number of epochs (default: 100).')
    parser.add_argument('--budget', type=int, default=300000, help='Maximum number of features in the feature bank (default: 300000).')
    parser.add_argument('--obj-n', type=int, default=2, help='Maximum number of objects trained simultaneously.')
    parser.add_argument('--clip-n', type=int, default=6, help='Maximum number of frames in a batch.')
    return parser.parse_args()
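
# Typical invocation (only --dataset is required; the dataset path below is a
# placeholder): python train_video_seg.py --dataset /path/to/train_data --log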


def train_model(model, dataloader, criterion, optimizer):
    stats = myutils.AvgMeter()
    uncertainty_stats = myutils.AvgMeter()
    progress_bar = tqdm(dataloader)

    for sample in progress_bar:
        frames, masks, obj_n, info = sample

        # Skip clips that contain only a single object label.
        if obj_n.item() == 1:
            continue

        frames, masks = frames[0].to(device), masks[0].to(device)

        # Memorize the first frame/mask pair into a fresh feature bank,
        # then segment the remaining frames of the clip against it.
        fb_global = FeatureBank(obj_n.item(), args.budget, device)
        k4_list, v4_list = model.memorize(frames[:1], masks[:1])
        fb_global.init_bank(k4_list, v4_list)
        scores, uncertainty = model.segment(frames[1:], fb_global)

        # Cross-entropy against the ground-truth labels plus an uncertainty
        # regularization term weighted by --lu.
        label = torch.argmax(masks[1:], dim=1).long()
        optimizer.zero_grad()
        loss = criterion(scores, label) + args.lu * uncertainty
        loss.backward()
        optimizer.step()

        uncertainty_stats.update(uncertainty.item())
        stats.update(loss.item())
        progress_bar.set_postfix(loss=f'{loss.item():.5f} (Avg: {stats.avg:.5f}, Uncertainty Avg: {uncertainty_stats.avg:.5f})')

    return stats.avg
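

# main() builds the dataset/dataloader and model, optionally resumes from a
# checkpoint, and runs the epoch loop with a StepLR schedule that halves the
# learning rate every --scheduler-step epochs. Checkpoints are written only
# when --log is set.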
def main():
    dataset = Water_Image_Train_DS(root=args.dataset, output_size=400, clip_n=args.clip_n, max_obj_n=args.obj_n)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2, pin_memory=True)
    print(myutils.gct(), f'Dataset with {len(dataset)} training cases.')

    model = AFB_URR(device, update_bank=False, load_imagenet_params=True).to(device)
    model.train()
    model.apply(myutils.set_bn_eval)  # keep BatchNorm layers in eval mode during training

    optimizer = torch.optim.AdamW(filter(lambda x: x.requires_grad, model.parameters()), lr=args.lr)
    start_epoch, best_loss = 0, float('inf')

    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location=device)
            model.load_state_dict(checkpoint['model'], strict=False)
            seed = checkpoint.get('seed', int(time.time()))
            if not args.new:
                start_epoch = checkpoint['epoch'] + 1
                optimizer.load_state_dict(checkpoint['optimizer'])
                best_loss = checkpoint['loss']
                print(myutils.gct(), f'Resumed from checkpoint {args.resume} (Epoch: {start_epoch - 1}, Best Loss: {best_loss}).')
            else:
                print(myutils.gct(), f'Loaded checkpoint {args.resume}. Training from scratch.')
        else:
            raise FileNotFoundError(f'No checkpoint found at {args.resume}')
    else:
        seed = args.seed if args.seed >= 0 else int(time.time())

    print(myutils.gct(), 'Random seed:', seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step, gamma=0.5, last_epoch=start_epoch - 1)

    for epoch in range(start_epoch, args.total_epochs):
        print(f'\n{myutils.gct()} Epoch: {epoch}, Learning Rate: {scheduler.get_last_lr()[0]:.6f}')
        loss = train_model(model, dataloader, criterion, optimizer)

        if args.log:
            checkpoint = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': loss,
                'seed': seed
            }
            torch.save(checkpoint, os.path.join(model_path, 'final.pth'))

            if loss < best_loss:
                best_loss = loss
                torch.save(checkpoint, os.path.join(model_path, 'best.pth'))
                print('Updated best model.')

        scheduler.step()
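

# Entry point: parse arguments, require a CUDA device, and (when --log is set)
# create a timestamped log directory that also receives a snapshot of the
# current scripts via myutils.save_scripts.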
if __name__ == '__main__':
    args = get_args()
    print(myutils.gct(), f'Arguments: {args}')

    if args.gpu >= 0 and torch.cuda.is_available():
        device = torch.device('cuda', args.gpu)
    else:
        raise ValueError('CUDA is required. Ensure --gpu is set to >= 0.')

    if args.log:
        log_dir = os.path.join('logs', time.strftime('%Y%m%d-%H%M%S'))
        model_path = os.path.join(log_dir, 'model')
        os.makedirs(model_path, exist_ok=True)
        myutils.save_scripts(log_dir, scripts_to_save=glob('*.*'))
        myutils.save_scripts(log_dir, scripts_to_save=glob('dataset/*.py', recursive=True))
        myutils.save_scripts(log_dir, scripts_to_save=glob('model/*.py', recursive=True))
        myutils.save_scripts(log_dir, scripts_to_save=glob('myutils/*.py', recursive=True))
        print(myutils.gct(), f'Created log directory: {log_dir}')

    main()
    print(myutils.gct(), 'Training completed.')