import os
import os.path as osp
from copy import deepcopy
from collections import OrderedDict
import glob
from datetime import datetime
import random
import copy
import imageio
import torch
try:
    import clip  # only required when memory_encoder == 'CLIP'
except ImportError:
    clip = None
import torchvision.transforms.functional as tvf
import video3d.utils.meters as meters
import video3d.utils.misc as misc
# from video3d.dataloaders import get_image_loader
from video3d.dataloaders_ddp import get_sequence_loader_ddp, get_sequence_loader_quadrupeds, get_test_loader_quadrupeds
from . import discriminator_architecture


def sample_frames(batch, num_sample_frames, iteration, stride=1):
    ## window-slicing sampling
    images, masks, flows, bboxs, bg_image, seq_idx, frame_idx = batch
    num_seqs, total_num_frames = images.shape[:2]
    # start_frame_idx = iteration % (total_num_frames - num_sample_frames + 1)
    ## sweep the window forward, then backward
    num_windows = total_num_frames - num_sample_frames + 1
    start_frame_idx = (iteration * stride) % (2 * num_windows)
    ## reflect the index: x' = (2n-1)/2 - |(2n-1)/2 - x|, e.g. 0,1,2,3,4,5 -> 0,1,2,2,1,0
    mid_val = (2 * num_windows - 1) / 2
    start_frame_idx = int(mid_val - abs(mid_val - start_frame_idx))
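    # Worked example of the reflection (illustrative numbers): with
    # total_num_frames = 7 and num_sample_frames = 3, num_windows = 5, so
    # successive iterations (stride=1) yield start_frame_idx
    #   0, 1, 2, 3, 4, 4, 3, 2, 1, 0, 0, 1, ...
    # i.e. the sampling window ping-pongs over the sequence instead of jumping
    # from the last window back to the first.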
    # note: flow is defined between consecutive frames, hence one frame fewer
    new_batch = images[:, start_frame_idx:start_frame_idx+num_sample_frames], \
                masks[:, start_frame_idx:start_frame_idx+num_sample_frames], \
                flows[:, start_frame_idx:start_frame_idx+num_sample_frames-1], \
                bboxs[:, start_frame_idx:start_frame_idx+num_sample_frames], \
                bg_image, \
                seq_idx, \
                frame_idx[:, start_frame_idx:start_frame_idx+num_sample_frames]
    return new_batch


def indefinite_generator(loader):
    while True:
        for x in loader:
            yield x


def indefinite_generator_from_list(loaders):
    # repeatedly pick a random loader and yield a single batch from it
    while True:
        random_idx = random.randint(0, len(loaders) - 1)
        for x in loaders[random_idx]:
            yield x
            break


def get_optimizer(model, lr=0.0001, betas=(0.9, 0.999), weight_decay=0):
    return torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr, betas=betas, weight_decay=weight_decay)


class Fewshot_Trainer:
    def __init__(self, cfgs, model):
        # currently supports only one GPU
        self.cfgs = cfgs
        # single-GPU DDP settings
        self.rank = cfgs.get('rank', 0)
        self.world_size = cfgs.get('world_size', 1)
        self.use_ddp = cfgs.get('use_ddp', True)
        self.device = cfgs.get('device', 'cpu')
        self.num_epochs = cfgs.get('num_epochs', 1)
        self.lr = cfgs.get('few_shot_lr', 1e-4)
        self.dataset = 'image'
        self.metrics_trace = meters.MetricsTrace()
        self.make_metrics = lambda m=None: meters.StandardMetrics(m)
        self.archive_code = cfgs.get('archive_code', True)
        self.batch_size = cfgs.get('batch_size', 64)
        self.in_image_size = cfgs.get('in_image_size', 256)
        self.out_image_size = cfgs.get('out_image_size', 256)
        self.num_workers = cfgs.get('num_workers', 4)
        self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results')
        misc.xmkdir(self.checkpoint_dir)
        self.few_shot_resume = cfgs.get('few_shot_resume', False)
        self.save_checkpoint_freq = cfgs.get('save_checkpoint_freq', 1)
        self.keep_num_checkpoint = cfgs.get('keep_num_checkpoint', 2)  # -1 keeps all checkpoints
        self.few_shot_data_dir = cfgs.get('few_shot_data_dir', None)
        assert self.few_shot_data_dir is not None
        # in case we add more data sources
        if isinstance(self.few_shot_data_dir, list):
            self.few_shot_data_dir_more = self.few_shot_data_dir[1:]
            self.few_shot_data_dir = self.few_shot_data_dir[0]
        else:
            self.few_shot_data_dir_more = None
        # TODO: a hacky guard against using the wrong data; remove eventually
        assert "data_resize_update" in self.few_shot_data_dir
        self.few_shot_categories, self.few_shot_categories_paths = self.parse_few_shot_categories(self.few_shot_data_dir, self.few_shot_data_dir_more)
        # If test categories are requested, pop them from self.few_shot_categories and
        # self.few_shot_categories_paths. Test categories must come from the few-shot set;
        # they are evaluated with batch size 1 and need no back-view enhancement (for now,
        # back-view images are used but not duplicated).
        self.test_category_num = cfgs.get('few_shot_test_category_num', 0)
        self.test_category_names = cfgs.get('few_shot_test_category_names', None)
        if self.test_category_num > 0:
            # if valid test category names are given, use them; their number need not match
            if self.test_category_names is not None:
                test_cats = self.test_category_names
            else:
                test_cats = list(self.few_shot_categories_paths.keys())[-(self.test_category_num):]
            test_categories_paths = {}
            for test_cat in test_cats:
                test_categories_paths.update({test_cat: self.few_shot_categories_paths[test_cat]})
                assert test_cat in self.few_shot_categories
                self.few_shot_categories.remove(test_cat)
                self.few_shot_categories_paths.pop(test_cat)
            self.test_categories_paths = test_categories_paths
        else:
            self.test_categories_paths = None
        # also load the original 7 categories
        self.original_train_data_path = cfgs.get('train_data_dir', None)
        self.original_val_data_path = cfgs.get('val_data_dir', None)
        self.original_categories = []
        self.original_categories_paths = self.original_train_data_path
        for k, v in self.original_train_data_path.items():
            self.original_categories.append(k)
        self.categories = self.original_categories + self.few_shot_categories
        self.categories_paths = self.original_train_data_path.copy()
        self.categories_paths.update(self.few_shot_categories_paths)
        print(f'Using {len(self.categories)} categories: ', self.categories)
        # initialize new components
        # self.original_classes_num = cfgs.get('few_shot_original_classes_num', 7)
        self.original_classes_num = len(self.original_categories)
        self.new_classes_num = len(self.categories) - self.original_classes_num
        self.combine_dataset = cfgs.get('combine_dataset', False)
        assert self.combine_dataset, "the combined dataset is the up-to-date code path; enable combine_dataset"
        if self.combine_dataset:
            self.train_loader, self.val_loader, self.test_loader = self.get_data_loaders_quadrupeds(self.cfgs, self.batch_size, self.num_workers, self.in_image_size, self.out_image_size)
        else:
            self.train_loader_few_shot, self.val_loader_few_shot = self.get_data_loaders_few_shot(self.cfgs, self.batch_size, self.num_workers, self.in_image_size, self.out_image_size)
            self.train_loader_original, self.val_loader_original = self.get_data_loaders_original(self.cfgs, self.batch_size, self.num_workers, self.in_image_size, self.out_image_size)
            self.train_loader = self.train_loader_original + self.train_loader_few_shot
            if self.val_loader_few_shot is not None and self.val_loader_original is not None:
                self.val_loader = self.val_loader_original + self.val_loader_few_shot
        self.num_iterations = cfgs.get('num_iterations', 0)
        self.use_total_iterations = self.num_iterations != 0
        if self.use_total_iterations:
            # rescale the epoch-related cfgs from the reference 120-epoch schedule
            dataloader_length = max([len(loader) for loader in self.train_loader]) * len(self.train_loader)
            print("Total length of data loader is: ", dataloader_length)
            total_epoch = int(self.num_iterations / dataloader_length) + 1
            print(f'run for {total_epoch} epochs')
            print('is_main_process()?', misc.is_main_process())
            for k, v in cfgs.items():
                if 'epoch' in k:
                    # if isinstance(v, list):
                    #     new_v = [int(total_epoch * x / 120) + 1 for x in v]
                    #     cfgs[k] = new_v
                    # elif isinstance(v, int):
                    #     new_v = int(total_epoch * v / 120) + 1
                    #     cfgs[k] = new_v
                    # a better transformation:
                    if isinstance(v, int):
                        # use the floor value
                        new_v = int(total_epoch * v / 120)
                        cfgs[k] = new_v
                    elif isinstance(v, list):
                        if len(v) > 1 and v[0] == v[1]:
                            # if the values are equal, floor both
                            new_v = [int(total_epoch * x / 120) for x in v]
                        else:
                            # otherwise floor the first value and ceil the rest
                            new_v = [int(total_epoch * x / 120) + 1 for x in v]
                            new_v[0] = new_v[0] - 1
                        cfgs[k] = new_v
                    else:
                        continue
            self.num_epochs = total_epoch
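            # Worked example of the rescaling (illustrative numbers): with
            # num_iterations = 60000 and dataloader_length = 500, total_epoch = 121;
            # an int cfg of 60 epochs becomes int(121 * 60 / 120) = 60, and a list
            # cfg [20, 100] becomes [21 - 1, int(121 * 100 / 120) + 1] = [20, 101].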
        self.cub_start_epoch = cfgs.get('cub_start_epoch', 0)
        self.cfgs = cfgs
        # the model has not been populated with weights yet
        self.model = model(cfgs)
        self.metrics_trace = meters.MetricsTrace()
        self.make_metrics = lambda m=None: meters.StandardMetrics(m)
        self.use_logger = True
        self.log_freq_images = cfgs.get('log_freq_images', 1000)
        self.log_train_images = cfgs.get('log_train_images', False)
        self.log_freq_losses = cfgs.get('log_freq_losses', 100)
        self.save_result_freq = cfgs.get('save_result_freq', None)
        self.train_result_dir = osp.join(self.checkpoint_dir, 'results')
        self.fix_viz_batch = cfgs.get('fix_viz_batch', False)
        self.visualize_validation = cfgs.get('visualize_validation', False)
        # self.visualize_validation = False
        self.iteration_save = cfgs.get('few_shot_iteration_save', False)
        self.iteration_save_freq = cfgs.get('few_shot_iteration_save_freq', 2000)
        self.enable_memory_bank = cfgs.get('enable_memory_bank', False)
        if self.enable_memory_bank:
            self.memory_bank_dim = 128
            self.memory_bank_size = cfgs.get('memory_bank_size', 60)
            self.memory_bank_topk = cfgs.get('memory_bank_topk', 10)
            assert self.memory_bank_topk <= self.memory_bank_size
            self.memory_retrieve = cfgs.get('memory_retrieve', 'cos-linear')
            self.memory_bank_init = cfgs.get('memory_bank_init', 'random')
            if self.memory_bank_init == 'copy':
                # initialize from the trained embeddings of the original categories
                num_piece = self.memory_bank_size // self.original_classes_num
                num_left = self.memory_bank_size - num_piece * self.original_classes_num
                tmp_1 = torch.empty_like(self.model.netPrior.classes_vectors)
                tmp_1 = tmp_1.copy_(self.model.netPrior.classes_vectors)
                tmp_1 = tmp_1.unsqueeze(0).repeat(num_piece, 1, 1)
                tmp_1 = tmp_1.reshape(tmp_1.shape[0] * tmp_1.shape[1], tmp_1.shape[-1])
                if num_left > 0:
                    tmp_2 = torch.empty_like(self.model.netPrior.classes_vectors)
                    tmp_2 = tmp_2.copy_(self.model.netPrior.classes_vectors)
                    tmp_2 = tmp_2[:num_left]
                    tmp = torch.cat([tmp_1, tmp_2], dim=0)
                else:
                    tmp = tmp_1
                self.memory_bank = torch.nn.Parameter(tmp, requires_grad=True)
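                # Worked example (illustrative): with 7 original classes and
                # memory_bank_size = 60, num_piece = 8 and num_left = 4, so the bank
                # holds 8 stacked copies of the 7 class vectors (56 rows) followed
                # by the first 4 class vectors.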
            elif self.memory_bank_init == 'random':
                self.memory_bank = torch.nn.Parameter(torch.nn.init.uniform_(torch.empty(self.memory_bank_size, self.memory_bank_dim), a=-0.05, b=0.05), requires_grad=True)
            else:
                raise NotImplementedError
            self.memory_encoder = cfgs.get('memory_encoder', 'DINO')  # for DINO, just use the network encoder
            if self.memory_encoder == 'CLIP':
                self.clip_model, _ = clip.load('ViT-B/32', self.device)
                self.clip_model = self.clip_model.eval().requires_grad_(False)
                self.clip_mean = [0.48145466, 0.4578275, 0.40821073]
                self.clip_std = [0.26862954, 0.26130258, 0.27577711]
                self.clip_reso = 224
                self.memory_bank_keys_dim = 512
            elif self.memory_encoder == 'DINO':
                self.memory_bank_keys_dim = 384
            else:
                raise NotImplementedError
            memory_bank_keys = torch.nn.init.uniform_(torch.empty(self.memory_bank_size, self.memory_bank_keys_dim), a=-0.05, b=0.05)
            self.memory_bank_keys = torch.nn.Parameter(memory_bank_keys, requires_grad=True)
        else:
            print("no memory bank, just using the image embedding; this is only for one experiment!")
            self.memory_encoder = cfgs.get('memory_encoder', 'DINO')  # for DINO, just use the network encoder
            if self.memory_encoder == 'CLIP':
                self.clip_model, _ = clip.load('ViT-B/32', self.device)
                self.clip_model = self.clip_model.eval().requires_grad_(False)
                self.clip_mean = [0.48145466, 0.4578275, 0.40821073]
                self.clip_std = [0.26862954, 0.26130258, 0.27577711]
                self.clip_reso = 224
                self.memory_bank_keys_dim = 512
            elif self.memory_encoder == 'DINO':
                self.memory_bank_keys_dim = 384
            else:
                raise NotImplementedError
        self.prepare_model()

    def parse_few_shot_categories(self, data_dir, data_dir_more=None):
        # parse the categories under data_dir
        few_shot_category_num = self.cfgs.get('few_shot_category_num', -1)
        assert few_shot_category_num != 0
        categories = sorted(os.listdir(data_dir))
        cnt = 0
        category_names = []
        category_names_paths = {}
        for category in categories:
            if osp.isdir(osp.join(data_dir, category, 'train')):
                category_path = osp.join(data_dir, category, 'train')
                category_names.append(category)
                category_names_paths.update({category: category_path})
                cnt += 1
            if few_shot_category_num > 0 and cnt >= few_shot_category_num:
                break
        # additional data sources
        if data_dir_more is not None:
            for data_dir_one in data_dir_more:
                new_categories = os.listdir(data_dir_one)
                for new_category in new_categories:
                    '''
                    if this category has not been seen before, add a new entry;
                    otherwise append the new path to the existing ones:
                        if the existing value is a str, turn it into a list,
                        if it is already a list, append to it
                    '''
                    if new_category not in category_names:
                        # TODO: hacky: if a new-data category collides with one of the
                        # original categories, register it as a separate category
                        # (keep the on-disk name for the path; only the registered name is prefixed)
                        category_dir = osp.join(data_dir_one, new_category, 'train')
                        if new_category in list(self.cfgs.get('train_data_dir', None).keys()):
                            new_category = '_' + new_category
                        category_names.append(new_category)
                        category_names_paths.update({new_category: category_dir})
                    else:
                        old_category_path = category_names_paths[new_category]
                        if isinstance(old_category_path, str):
                            category_names_paths[new_category] = [
                                old_category_path,
                                osp.join(data_dir_one, new_category, 'train')
                            ]
                        elif isinstance(old_category_path, list):
                            old_category_path = old_category_path + [osp.join(data_dir_one, new_category, 'train')]
                            category_names_paths[new_category] = old_category_path
                        else:
                            raise NotImplementedError
        # category_names = sorted(category_names)
        return category_names, category_names_paths
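
    # The parser returns, e.g. (paths below are hypothetical):
    #   category_names = ['bear', 'deer']
    #   category_names_paths = {
    #       'bear': '/data/few_shot/bear/train',
    #       'deer': ['/data/few_shot/deer/train', '/data/more/deer/train'],
    #   }
    # where a list value means the category aggregates several data sources.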

    def prepare_model(self):
        # prepare the model weights from outside the model:
        # 1. load the pre-trained weights
        # 2. initialize anything new, e.g. the new class vectors
        # 3. initialize a new optimizer for the chosen parameters
        assert self.original_classes_num == len(self.model.netPrior.category_id_map)
        # load the pre-trained weights; skipped if few_shot_checkpoint_name is not set
        if self.cfgs.get('few_shot_checkpoint_name', None) is not None:
            original_checkpoint_path = osp.join(self.checkpoint_dir, self.cfgs.get('few_shot_checkpoint_name', 'checkpoint060.pth'))
            assert osp.exists(original_checkpoint_path)
            print(f"Loading pre-trained checkpoint from {original_checkpoint_path}")
            cp = torch.load(original_checkpoint_path, map_location=self.device)
            # if fine-tuning uses a local texture network, the texture weights in the
            # pre-trained checkpoint are global, so (as a hack) we simply drop them
            if (self.cfgs.get('texture_way', None) is not None) or (self.cfgs.get('texture_act', 'relu') != 'relu'):
                new_netInstance_weights = {k: v for k, v in cp['netInstance'].items() if 'netTexture' not in k}
                # find the new texture weights
                texture_weights = self.model.netInstance.netTexture.state_dict()
                # add the new weights to the new model weights
                for k, v in texture_weights.items():
                    # the overlapping part of netTexture could also be reused:
                    # if ('netTexture.' + k) in cp['netInstance'].keys():
                    #     new_netInstance_weights['netTexture.' + k] = cp['netInstance']['netTexture.' + k]
                    # else:
                    #     new_netInstance_weights['netTexture.' + k] = v
                    new_netInstance_weights['netTexture.' + k] = v
                _ = cp.pop("netInstance")
                cp.update({"netInstance": new_netInstance_weights})
            self.model.netInstance.load_state_dict(cp["netInstance"], strict=False)  # non-strict for Deform
            # self.model.netInstance.load_state_dict(cp["netInstance"])
            self.model.netPrior.load_state_dict(cp["netPrior"])
            self.original_total_iter = cp["total_iter"]
        else:
            print("no pre-trained weights loaded; the iteration counter starts from 0, so make sure all needed parameters are set")
            self.original_total_iter = 0
        if not self.cfgs.get('disable_fewshot', False):
            for i, category in enumerate(self.few_shot_categories):
                category_id = self.original_classes_num + i
                self.model.netPrior.category_id_map.update({category: category_id})
            few_shot_class_vector_init = self.cfgs.get('few_shot_class_vector_init', 'random')
            if few_shot_class_vector_init == 'random':
                tmp = torch.nn.init.uniform_(torch.empty(self.new_classes_num, self.model.netPrior.classes_vectors.shape[-1]), a=-0.05, b=0.05)
                tmp = tmp.to(self.model.netPrior.classes_vectors.device)
                self.model.netPrior.classes_vectors = torch.nn.Parameter(torch.cat([self.model.netPrior.classes_vectors, tmp], dim=0))
            elif few_shot_class_vector_init == 'copy':
                # tile the trained class vectors to initialize the new ones
                num_7_cat_piece = self.new_classes_num // self.original_classes_num if self.new_classes_num > self.original_classes_num else 0
                num_left = self.new_classes_num - num_7_cat_piece * self.original_classes_num
                if num_7_cat_piece > 0:
                    tmp_1 = torch.empty_like(self.model.netPrior.classes_vectors)
                    tmp_1 = tmp_1.copy_(self.model.netPrior.classes_vectors)
                    tmp_1 = tmp_1.unsqueeze(0).repeat(num_7_cat_piece, 1, 1)
                    tmp_1 = tmp_1.reshape(tmp_1.shape[0] * tmp_1.shape[1], tmp_1.shape[-1])
                else:
                    tmp_1 = None
                if num_left > 0:
                    tmp_2 = torch.empty_like(self.model.netPrior.classes_vectors)
                    tmp_2 = tmp_2.copy_(self.model.netPrior.classes_vectors)
                    tmp_2 = tmp_2[:num_left]
                else:
                    tmp_2 = None
                if tmp_1 is not None and tmp_2 is not None:
                    tmp = torch.cat([tmp_1, tmp_2], dim=0)
                elif tmp_1 is None and tmp_2 is not None:
                    tmp = tmp_2
                elif tmp_2 is None and tmp_1 is not None:
                    tmp = tmp_1
                else:
                    raise NotImplementedError
                tmp = tmp.to(self.model.netPrior.classes_vectors.device)
                self.model.netPrior.classes_vectors = torch.nn.Parameter(torch.cat([self.model.netPrior.classes_vectors, tmp], dim=0))
            else:
                raise NotImplementedError
        else:
            print("few-shot disabled, not adding new embedding vectors")
        # initialize the new optimizer
        optimize_rule = self.cfgs.get('few_shot_optimize', 'all')
        if optimize_rule == 'all':
            optimize_list = [
                {'name': 'net_Prior', 'params': list(self.model.netPrior.parameters()), 'lr': self.lr * 10.},
                {'name': 'net_Instance', 'params': list(self.model.netInstance.parameters()), 'lr': self.lr * 1.},
            ]
        elif optimize_rule == 'only-emb':
            optimize_list = [
                {'name': 'class_embeddings', 'params': list([self.model.netPrior.classes_vectors]), 'lr': self.lr * 10.}
            ]
        elif optimize_rule == 'emb-instance':
            optimize_list = [
                {'name': 'class_embeddings', 'params': list([self.model.netPrior.classes_vectors]), 'lr': self.lr * 10.},
                {'name': 'net_Instance', 'params': list(self.model.netInstance.parameters()), 'lr': self.lr * 1.},
            ]
        elif optimize_rule == 'custom':
            optimize_list = [
                {'name': 'net_Prior', 'params': list(self.model.netPrior.parameters()), 'lr': self.lr * 10.},
                {'name': 'netEncoder', 'params': list(self.model.netInstance.netEncoder.parameters()), 'lr': self.lr * 1.},
                {'name': 'netTexture', 'params': list(self.model.netInstance.netTexture.parameters()), 'lr': self.lr * 1.},
                {'name': 'netPose', 'params': list(self.model.netInstance.netPose.parameters()), 'lr': self.lr * 0.01},
                {'name': 'netArticulation', 'params': list(self.model.netInstance.netArticulation.parameters()), 'lr': self.lr * 1.},
                {'name': 'netLight', 'params': list(self.model.netInstance.netLight.parameters()), 'lr': self.lr * 1.}
            ]
        elif optimize_rule == 'custom-deform':
            optimize_list = [
                {'name': 'net_Prior', 'params': list(self.model.netPrior.parameters()), 'lr': self.lr * 10.},
                {'name': 'netEncoder', 'params': list(self.model.netInstance.netEncoder.parameters()), 'lr': self.lr * 1.},
                {'name': 'netTexture', 'params': list(self.model.netInstance.netTexture.parameters()), 'lr': self.lr * 1.},
                {'name': 'netPose', 'params': list(self.model.netInstance.netPose.parameters()), 'lr': self.lr * 0.01},
                {'name': 'netArticulation', 'params': list(self.model.netInstance.netArticulation.parameters()), 'lr': self.lr * 1.},
                {'name': 'netLight', 'params': list(self.model.netInstance.netLight.parameters()), 'lr': self.lr * 1.},
                {'name': 'netDeform', 'params': list(self.model.netInstance.netDeform.parameters()), 'lr': self.lr * 1.}
            ]
        elif optimize_rule == 'texture':
            optimize_list = [
                {'name': 'netTexture', 'params': list(self.model.netInstance.netTexture.parameters()), 'lr': self.lr * 1.}
            ]
        elif optimize_rule == 'texture-light':
            optimize_list = [
                {'name': 'netTexture', 'params': list(self.model.netInstance.netTexture.parameters()), 'lr': self.lr * 1.},
                {'name': 'netLight', 'params': list(self.model.netInstance.netLight.parameters()), 'lr': self.lr * 1.}
            ]
        elif optimize_rule == 'exp':
            optimize_list = [
                {'name': 'net_Prior', 'params': list(self.model.netPrior.parameters()), 'lr': self.lr * 10.},
                {'name': 'netEncoder', 'params': list(self.model.netInstance.netEncoder.parameters()), 'lr': self.lr * 1.},
                {'name': 'netTexture', 'params': list(self.model.netInstance.netTexture.parameters()), 'lr': self.lr * 1.},
                {'name': 'netPose', 'params': list(self.model.netInstance.netPose.parameters()), 'lr': self.lr * 1.},
                {'name': 'netArticulation', 'params': list(self.model.netInstance.netArticulation.parameters()), 'lr': self.lr * 1.},
                {'name': 'netLight', 'params': list(self.model.netInstance.netLight.parameters()), 'lr': self.lr * 1.},
                {'name': 'netDeform', 'params': list(self.model.netInstance.netDeform.parameters()), 'lr': self.lr * 1.}
            ]
        else:
            raise NotImplementedError
        if self.enable_memory_bank and optimize_rule != 'texture':
            optimize_bank_components = self.cfgs.get('few_shot_optimize_bank', 'all')
            if optimize_bank_components == 'value':
                optimize_list += [
                    {'name': 'memory_bank', 'params': list([self.memory_bank]), 'lr': self.lr * 10.}
                ]
            elif optimize_bank_components == 'key':
                optimize_list += [
                    {'name': 'memory_bank_keys', 'params': list([self.memory_bank_keys]), 'lr': self.lr * 10.}
                ]
            else:
                optimize_list += [
                    {'name': 'memory_bank', 'params': list([self.memory_bank]), 'lr': self.lr * 10.},
                    {'name': 'memory_bank_keys', 'params': list([self.memory_bank_keys]), 'lr': self.lr * 10.}
                ]
        if self.model.enable_vsd:
            optimize_list += [
                {'name': 'lora', 'params': list(self.model.stable_diffusion.parameters()), 'lr': self.lr}
            ]
        # self.optimizerFewShot = torch.optim.Adam(
        #     [
        #         # {'name': 'class_embeddings', 'params': list([self.model.netPrior.classes_vectors]), 'lr': self.lr * 1.},
        #         {'name': 'net_Prior', 'params': list(self.model.netPrior.parameters()), 'lr': self.lr * 10.},
        #         {'name': 'net_Instance', 'params': list(self.model.netInstance.parameters()), 'lr': self.lr * 1.},
        #         # {'name': 'net_articulation', 'params': list(self.model.netInstance.netArticulation.parameters()), 'lr': self.lr * 10.}
        #     ], betas=(0.9, 0.99), eps=1e-15
        # )
        self.optimizerFewShot = torch.optim.Adam(optimize_list, betas=(0.9, 0.99), eps=1e-15)
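        # note: the 'name' entries above are not Adam hyper-parameters; torch.optim
        # stores unrecognized param-group keys untouched, so they serve purely as
        # bookkeeping labels for the parameter groups.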
        # if self.cfgs.get('texture_way', None) is not None and self.cfgs.get('gan_tex', False):
        if self.cfgs.get('gan_tex', False):
            self.optimizerDiscTex = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.discriminator_texture.parameters()), lr=self.lr, betas=(0.9, 0.99), eps=1e-15)

    def load_checkpoint(self, optim=True, checkpoint_name=None):
        """Search for the specified/latest checkpoint in checkpoint_dir and load the model and optimizer."""
        # used to load the model and optimizer checkpoints during fine-tuning
        if checkpoint_name is not None:
            checkpoint_path = osp.join(self.checkpoint_dir, checkpoint_name)
        else:
            checkpoints = sorted(glob.glob(osp.join(self.checkpoint_dir, '*.pth')))
            if len(checkpoints) == 0:
                return 0, 0
            checkpoint_path = checkpoints[-1]
        self.checkpoint_name = osp.basename(checkpoint_path)
        print(f"Loading checkpoint from {checkpoint_path}")
        cp = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_model_state(cp)  # cp has netPrior and netInstance as keys
        if optim:
            try:
                self.optimizerFewShot.load_state_dict(cp['optimizerFewShot'])
            except Exception:
                print('optimizer state not loaded; this is expected when switching to the local texture network')
        if self.enable_memory_bank:
            self.memory_bank_keys = cp['memory_bank_keys']
            self.memory_bank = cp['memory_bank']
        self.metrics_trace = cp['metrics_trace']
        epoch = cp['epoch']
        total_iter = cp['total_iter']
        return epoch, total_iter

    def save_checkpoint(self, epoch, total_iter=0, optim=True, use_iter=False):
        """Save model, optimizer, and metrics state to a checkpoint in checkpoint_dir for the specified epoch."""
        misc.xmkdir(self.checkpoint_dir)
        if use_iter:
            checkpoint_path = osp.join(self.checkpoint_dir, f'iter{total_iter:07}.pth')
            prefix = 'iter*.pth'
        else:
            checkpoint_path = osp.join(self.checkpoint_dir, f'checkpoint{epoch:03}.pth')
            prefix = 'checkpoint*.pth'
        state_dict = self.model.get_model_state()
        if optim:
            optimizer_state = {'optimizerFewShot': self.optimizerFewShot.state_dict()}
            state_dict = {**state_dict, **optimizer_state}
        state_dict['metrics_trace'] = self.metrics_trace
        state_dict['epoch'] = epoch
        state_dict['total_iter'] = total_iter
        if self.enable_memory_bank:
            state_dict['memory_bank_keys'] = self.memory_bank_keys
            state_dict['memory_bank'] = self.memory_bank
        print(f"Saving checkpoint to {checkpoint_path}")
        torch.save(state_dict, checkpoint_path)
        if self.keep_num_checkpoint > 0:
            self.clean_checkpoint(self.checkpoint_dir, keep_num=self.keep_num_checkpoint, prefix=prefix)

    def clean_checkpoint(self, checkpoint_dir, keep_num=2, prefix='checkpoint*.pth'):
        if keep_num > 0:
            names = list(sorted(
                glob.glob(os.path.join(checkpoint_dir, prefix))
            ))
            if len(names) > keep_num:
                for name in names[:-keep_num]:
                    print(f"Deleting obsolete checkpoint file {name}")
                    os.remove(name)

    def get_data_loaders_few_shot(self, cfgs, batch_size, num_workers, in_image_size, out_image_size):
        # builds the train data loaders (and, in principle, an identical val data loader)
        train_loader = val_loader = None
        color_jitter_train = cfgs.get('color_jitter_train', None)
        color_jitter_val = cfgs.get('color_jitter_val', None)
        random_flip_train = cfgs.get('random_flip_train', False)
        data_loader_mode = cfgs.get('data_loader_mode', 'n_frame')
        num_sample_frames = cfgs.get('num_sample_frames', 2)
        shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False)
        load_background = cfgs.get('background_mode', 'none') == 'background'
        rgb_suffix = cfgs.get('rgb_suffix', '.png')
        load_dino_feature = cfgs.get('load_dino_feature', False)
        dino_feature_dim = cfgs.get('dino_feature_dim', 64)
        get_loader_ddp = lambda **kwargs: get_sequence_loader_ddp(
            mode=data_loader_mode,
            batch_size=batch_size,
            num_workers=num_workers,
            in_image_size=in_image_size,
            out_image_size=out_image_size,
            num_sample_frames=num_sample_frames,
            load_background=load_background,
            rgb_suffix=rgb_suffix,
            load_dino_feature=load_dino_feature,
            dino_feature_dim=dino_feature_dim,
            flow_bool=0,
            **kwargs)
        print("Loading training data...")
        train_loader = get_loader_ddp(data_dir=[self.original_classes_num, self.few_shot_categories_paths], rank=self.rank, world_size=self.world_size, use_few_shot=True, shuffle=False, color_jitter=color_jitter_train, random_flip=random_flip_train)
        return train_loader, val_loader

    def get_data_loaders_original(self, cfgs, batch_size, num_workers, in_image_size, out_image_size):
        train_loader = val_loader = test_loader = None
        color_jitter_train = cfgs.get('color_jitter_train', None)
        color_jitter_val = cfgs.get('color_jitter_val', None)
        random_flip_train = cfgs.get('random_flip_train', False)
        data_loader_mode = cfgs.get('data_loader_mode', 'n_frame')
        skip_beginning = cfgs.get('skip_beginning', 4)
        skip_end = cfgs.get('skip_end', 4)
        num_sample_frames = cfgs.get('num_sample_frames', 2)
        min_seq_len = cfgs.get('min_seq_len', 10)
        max_seq_len = cfgs.get('max_seq_len', 10)
        debug_seq = cfgs.get('debug_seq', False)
        random_sample_train_frames = cfgs.get('random_sample_train_frames', False)
        shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False)
        random_sample_val_frames = cfgs.get('random_sample_val_frames', False)
        load_background = cfgs.get('background_mode', 'none') == 'background'
        rgb_suffix = cfgs.get('rgb_suffix', '.png')
        load_dino_feature = cfgs.get('load_dino_feature', False)
        load_dino_cluster = cfgs.get('load_dino_cluster', False)
        dino_feature_dim = cfgs.get('dino_feature_dim', 64)
        get_loader_ddp = lambda **kwargs: get_sequence_loader_ddp(
            mode=data_loader_mode,
            batch_size=batch_size,
            num_workers=num_workers,
            in_image_size=in_image_size,
            out_image_size=out_image_size,
            debug_seq=debug_seq,
            skip_beginning=skip_beginning,
            skip_end=skip_end,
            num_sample_frames=num_sample_frames,
            min_seq_len=min_seq_len,
            max_seq_len=max_seq_len,
            load_background=load_background,
            rgb_suffix=rgb_suffix,
            load_dino_feature=load_dino_feature,
            load_dino_cluster=load_dino_cluster,
            dino_feature_dim=dino_feature_dim,
            flow_bool=0,
            **kwargs)
        # just the train loader for now
        train_data_dir = self.original_categories_paths
        if isinstance(train_data_dir, dict):
            for data_path in train_data_dir.values():
                assert osp.isdir(data_path), f"Training data directory does not exist: {data_path}"
        elif isinstance(train_data_dir, str):
            assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}"
        else:
            raise ValueError("train_data_dir must be a string or a dict of strings")
        print("Loading training data...")
        # train_data_dir is a dict, so this goes into the original dataset type
        train_loader = get_loader_ddp(data_dir=train_data_dir, rank=self.rank, world_size=self.world_size, is_validation=False, use_few_shot=False, random_sample=random_sample_train_frames, shuffle=shuffle_train_seqs, dense_sample=True, color_jitter=color_jitter_train, random_flip=random_flip_train)
        return train_loader, val_loader

    def get_data_loaders_quadrupeds(self, cfgs, batch_size, num_workers, in_image_size, out_image_size):
        train_loader = val_loader = test_loader = None
        color_jitter_train = cfgs.get('color_jitter_train', None)
        color_jitter_val = cfgs.get('color_jitter_val', None)
        random_flip_train = cfgs.get('random_flip_train', False)
        data_loader_mode = cfgs.get('data_loader_mode', 'n_frame')
        skip_beginning = cfgs.get('skip_beginning', 4)
        skip_end = cfgs.get('skip_end', 4)
        num_sample_frames = cfgs.get('num_sample_frames', 2)
        min_seq_len = cfgs.get('min_seq_len', 10)
        max_seq_len = cfgs.get('max_seq_len', 10)
        debug_seq = cfgs.get('debug_seq', False)
        random_sample_train_frames = cfgs.get('random_sample_train_frames', False)
        shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False)
        random_sample_val_frames = cfgs.get('random_sample_val_frames', False)
        load_background = cfgs.get('background_mode', 'none') == 'background'
        rgb_suffix = cfgs.get('rgb_suffix', '.png')
        load_dino_feature = cfgs.get('load_dino_feature', False)
        load_dino_cluster = cfgs.get('load_dino_cluster', False)
        dino_feature_dim = cfgs.get('dino_feature_dim', 64)
        enhance_back_view = cfgs.get('enhance_back_view', False)
        enhance_back_view_path = cfgs.get('enhance_back_view_path', None)
        override_categories = cfgs.get('override_categories', None)
        disable_fewshot = cfgs.get('disable_fewshot', False)
        dataset_split_num = cfgs.get('dataset_split_num', -1)
        get_loader_ddp = lambda **kwargs: get_sequence_loader_quadrupeds(
            mode=data_loader_mode,
            num_workers=num_workers,
            in_image_size=in_image_size,
            out_image_size=out_image_size,
            debug_seq=debug_seq,
            skip_beginning=skip_beginning,
            skip_end=skip_end,
            num_sample_frames=num_sample_frames,
            min_seq_len=min_seq_len,
            max_seq_len=max_seq_len,
            load_background=load_background,
            rgb_suffix=rgb_suffix,
            load_dino_feature=load_dino_feature,
            load_dino_cluster=load_dino_cluster,
            dino_feature_dim=dino_feature_dim,
            flow_bool=0,
            enhance_back_view=enhance_back_view,
            enhance_back_view_path=enhance_back_view_path,
            override_categories=override_categories,
            disable_fewshot=disable_fewshot,
            dataset_split_num=dataset_split_num,
            **kwargs)
        print("Loading training data...")
        val_image_num = cfgs.get('few_shot_val_image_num', 5)
        # the data dirs are dicts, so this goes into the original dataset type
        train_loader = get_loader_ddp(original_data_dirs=self.original_categories_paths, few_shot_data_dirs=self.few_shot_categories_paths, original_num=self.original_classes_num, few_shot_num=self.new_classes_num, rank=self.rank, world_size=self.world_size, batch_size=batch_size, is_validation=False, val_image_num=val_image_num, shuffle=shuffle_train_seqs, dense_sample=True, color_jitter=color_jitter_train, random_flip=random_flip_train)
        val_loader = get_loader_ddp(original_data_dirs=self.original_val_data_path, few_shot_data_dirs=self.few_shot_categories_paths, original_num=self.original_classes_num, few_shot_num=self.new_classes_num, rank=self.rank, world_size=self.world_size, batch_size=1, is_validation=True, val_image_num=val_image_num, shuffle=False, dense_sample=True, color_jitter=color_jitter_val, random_flip=False)
        if self.test_categories_paths is not None:
            get_test_loader_ddp = lambda **kwargs: get_test_loader_quadrupeds(
                mode=data_loader_mode,
                num_workers=num_workers,
                in_image_size=in_image_size,
                out_image_size=out_image_size,
                debug_seq=debug_seq,
                skip_beginning=skip_beginning,
                skip_end=skip_end,
                num_sample_frames=num_sample_frames,
                min_seq_len=min_seq_len,
                max_seq_len=max_seq_len,
                load_background=load_background,
                rgb_suffix=rgb_suffix,
                load_dino_feature=load_dino_feature,
                load_dino_cluster=load_dino_cluster,
                dino_feature_dim=dino_feature_dim,
                flow_bool=0,
                enhance_back_view=enhance_back_view,
                enhance_back_view_path=enhance_back_view_path,
                **kwargs)
            print("Loading testing data...")
            test_loader = get_test_loader_ddp(test_data_dirs=self.test_categories_paths, rank=self.rank, world_size=self.world_size, batch_size=1, is_validation=True, shuffle=False, dense_sample=True, color_jitter=color_jitter_val, random_flip=False)
        else:
            test_loader = None
        return train_loader, val_loader, test_loader

    def forward_frozen_ViT(self, images):
        # run the frozen pre-trained ViT
        x = images
        with torch.no_grad():
            b, c, h, w = x.shape
            self.model.netInstance.netEncoder._feats = []
            self.model.netInstance.netEncoder._register_hooks([11], 'key')
            # self._register_hooks([11], 'token')
            x = self.model.netInstance.netEncoder.ViT.prepare_tokens(x)
            # x = self.ViT.prepare_tokens_with_masks(x)
            for blk in self.model.netInstance.netEncoder.ViT.blocks:
                x = blk(x)
            out = self.model.netInstance.netEncoder.ViT.norm(x)
            self.model.netInstance.netEncoder._unregister_hooks()
            ph, pw = h // self.model.netInstance.netEncoder.patch_size, w // self.model.netInstance.netEncoder.patch_size
            patch_out = out[:, 1:]  # the first token is the class token
            patch_out = patch_out.reshape(b, ph, pw, self.model.netInstance.netEncoder.vit_feat_dim).permute(0, 3, 1, 2)
            patch_key = self.model.netInstance.netEncoder._feats[0][:, :, 1:]  # B, num_heads, num_patches, dim
            patch_key = patch_key.permute(0, 1, 3, 2).reshape(b, self.model.netInstance.netEncoder.vit_feat_dim, ph, pw)
            global_feat = out[:, 0]
        return global_feat
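
    # Shape summary for the frozen encoder above (dims follow the DINO ViT branch,
    # where vit_feat_dim = 384): images [B, 3, H, W] -> tokens [B, 1 + ph*pw, 384]
    # -> global_feat [B, 384]; only the class-token feature is returned as the
    # per-image descriptor used for memory-bank retrieval.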

    def forward_fix_embeddings(self, batch):
        images = batch[0]
        images = images.to(self.device)
        batch_size, num_frames, _, h0, w0 = images.shape
        images = images.reshape(batch_size * num_frames, *images.shape[2:])  # values in 0~1
        if self.memory_encoder == 'DINO':
            images_in = images * 2 - 1  # rescale to (-1, 1)
            batch_features = self.forward_frozen_ViT(images_in)
        elif self.memory_encoder == 'CLIP':
            images_in = torch.nn.functional.interpolate(images, (self.clip_reso, self.clip_reso), mode='bilinear')
            images_in = tvf.normalize(images_in, self.clip_mean, self.clip_std)
            batch_features = self.clip_model.encode_image(images_in).float()
        else:
            raise NotImplementedError
        return batch_features

    def retrieve_memory_bank(self, batch_features, batch):
        batch_size = batch_features.shape[0]
        if self.memory_retrieve == 'cos-linear':
            query = torch.nn.functional.normalize(batch_features.unsqueeze(1), dim=-1)  # [B, 1, d_k]
            key = torch.nn.functional.normalize(self.memory_bank_keys, dim=-1)  # [size, d_k]
            key = key.transpose(1, 0).unsqueeze(0).repeat(batch_size, 1, 1).to(query.device)  # [B, d_k, size]
            cos_dist = torch.bmm(query, key).squeeze(1)  # [B, size], larger means more similar
            rank_idx = torch.sort(cos_dist, dim=-1, descending=True)[1][:, :self.memory_bank_topk]  # [B, k]
            value = self.memory_bank.unsqueeze(0).repeat(batch_size, 1, 1).to(query.device)  # [B, size, d_v]
            out = torch.gather(value, dim=1, index=rank_idx[..., None].repeat(1, 1, self.memory_bank_dim))  # [B, k, d_v]
            weights = torch.gather(cos_dist, dim=-1, index=rank_idx)  # [B, k]
            weights = torch.nn.functional.normalize(weights, p=1.0, dim=-1).unsqueeze(-1).repeat(1, 1, self.memory_bank_dim)  # [B, k, d_v], L1-normalized weights
            out = weights * out
            out = torch.sum(out, dim=1)
        else:
            raise NotImplementedError
        batch_mean_out = torch.mean(out, dim=0)
        weight_aux = {
            'weights': weights[:, :, 0],  # [B, k], weights sorted from large to small
            'pick_idx': rank_idx,  # [B, k]
        }
        return batch_mean_out, out, weight_aux
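
    # A minimal, self-contained sketch of the 'cos-linear' retrieval above
    # (sizes are illustrative; d_k = key dim, d_v = value dim):
    #   >>> import torch
    #   >>> B, size, k, d_k, d_v = 2, 60, 10, 384, 128
    #   >>> feats, keys, values = torch.randn(B, d_k), torch.randn(size, d_k), torch.randn(size, d_v)
    #   >>> cos = torch.nn.functional.normalize(feats, dim=-1) @ torch.nn.functional.normalize(keys, dim=-1).T
    #   >>> idx = cos.topk(k, dim=-1).indices                                       # [B, k]
    #   >>> w = torch.nn.functional.normalize(cos.gather(-1, idx), p=1.0, dim=-1)   # L1-normalized weights
    #   >>> out = (w.unsqueeze(-1) * values[idx]).sum(dim=1)                        # [B, d_v]
    # torch.topk here is equivalent to the sort-then-slice used in the method.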

    def discriminator_texture_step(self):
        image_iv = self.model.record_image_iv
        image_rv = self.model.record_image_rv
        image_gt = self.model.record_image_gt
        self.model.record_image_iv = None
        self.model.record_image_rv = None
        self.model.record_image_gt = None
        image_iv = image_iv.requires_grad_(True)
        image_rv = image_rv.requires_grad_(True)
        image_gt = image_gt.requires_grad_(True)
        self.optimizerDiscTex.zero_grad()
        # note: these terms stay Python floats (and .detach() below would fail)
        # unless the corresponding branch runs, which the gan_tex configs ensure
        disc_loss_gt = 0.0
        disc_loss_iv = 0.0
        disc_loss_rv = 0.0
        grad_penalty = 0.0
        # the ground-truth image can only be treated as real
        if 'gt' in self.model.few_shot_gan_tex_real:
            d_gt = self.model.discriminator_texture(image_gt)
            disc_loss_gt += discriminator_architecture.bce_loss_target(d_gt, 1)
            if image_gt.requires_grad:
                grad_penalty_gt = 10. * discriminator_architecture.compute_grad2(d_gt, image_gt)
                disc_loss_gt += grad_penalty_gt
                grad_penalty += grad_penalty_gt
        # the input-view image can be treated as real or fake
        if 'iv' in self.model.few_shot_gan_tex_real:
            d_iv = self.model.discriminator_texture(image_iv)
            disc_loss_iv += discriminator_architecture.bce_loss_target(d_iv, 1)
            if image_iv.requires_grad:
                grad_penalty_iv = 10. * discriminator_architecture.compute_grad2(d_iv, image_iv)
                disc_loss_iv += grad_penalty_iv
                grad_penalty += grad_penalty_iv
        elif 'iv' in self.model.few_shot_gan_tex_fake:
            d_iv = self.model.discriminator_texture(image_iv)
            disc_loss_iv += discriminator_architecture.bce_loss_target(d_iv, 0)
        # the random-view image can only be treated as fake
        if 'rv' in self.model.few_shot_gan_tex_fake:
            d_rv = self.model.discriminator_texture(image_rv)
            disc_loss_rv += discriminator_architecture.bce_loss_target(d_rv, 0)
        all_loss = disc_loss_iv + disc_loss_rv + disc_loss_gt
        all_loss = all_loss * self.cfgs.get('gan_tex_loss_discriminator_weight', 0.1)
        self.discriminator_texture_loss = all_loss
        self.discriminator_texture_loss.backward()
        self.optimizerDiscTex.step()
        self.discriminator_texture_loss = 0.
        return {
            'discriminator_loss': all_loss.detach(),
            'discriminator_loss_iv': disc_loss_iv.detach(),
            'discriminator_loss_rv': disc_loss_rv.detach(),
            'discriminator_loss_gt': disc_loss_gt.detach(),
            'discriminator_loss_grad': grad_penalty.detach()
        }
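
    # The 10x gradient penalty used above is an R1-style regularizer: for images in
    # the "real" set it penalizes the squared gradient norm E[||grad_x D(x)||^2],
    # a standard way to stabilize GAN discriminator training (assuming
    # discriminator_architecture.compute_grad2 returns that squared norm).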

    def train(self):
        """Perform training."""
        # archive code and configs
        if self.archive_code:
            misc.archive_code(osp.join(self.checkpoint_dir, 'archived_code.zip'), filetypes=['.py'])
        misc.dump_yaml(osp.join(self.checkpoint_dir, 'configs.yml'), self.cfgs)
        # initialize
        start_epoch = 0
        self.total_iter = self.original_total_iter
        self.metrics_trace.reset()
        self.model.to(self.device)
        if self.model.enable_disc:
            self.model.reset_only_disc_optimizer()
        if self.few_shot_resume:
            resume_model_name = self.cfgs.get('few_shot_resume_name', None)
            start_epoch, self.total_iter = self.load_checkpoint(optim=True, checkpoint_name=resume_model_name)
        self.model.ddp(self.rank, self.world_size)
        # set up tensorboard
        if self.use_logger:
            from torch.utils.tensorboard import SummaryWriter
            self.logger = SummaryWriter(osp.join(self.checkpoint_dir, 'logs', datetime.now().strftime("%Y%m%d-%H%M%S")), flush_secs=10)
            # self.viz_data_iterator = indefinite_generator_from_list(self.val_loader) if self.visualize_validation else indefinite_generator_from_list(self.train_loader)
            self.viz_data_iterator = indefinite_generator(self.val_loader[0]) if self.visualize_validation else indefinite_generator(self.train_loader[0])
            if self.fix_viz_batch:
                self.viz_batch = next(self.viz_data_iterator)
            if self.test_loader is not None:
                self.viz_test_data_iterator = indefinite_generator(self.test_loader[0]) if self.visualize_validation else indefinite_generator(self.train_loader[0])
        # run the epochs
        epoch = 0
        for epoch in range(start_epoch, self.num_epochs):
            metrics = self.run_epoch(epoch)
            if self.combine_dataset:
                self.train_loader[0].dataset._shuffle_all()
            self.metrics_trace.append("train", metrics)
            if (epoch + 1) % self.save_checkpoint_freq == 0:
                self.save_checkpoint(epoch + 1, total_iter=self.total_iter, optim=True)
            # if self.cfgs.get('pyplot_metrics', True):
            #     self.metrics_trace.plot(pdf_path=osp.join(self.checkpoint_dir, 'metrics.pdf'))
            self.metrics_trace.save(osp.join(self.checkpoint_dir, 'metrics.json'))
        print(f"Training completed for all {epoch+1} epochs.")

    def run_epoch(self, epoch):
        """Run one training epoch."""
        metrics = self.make_metrics()
        self.model.set_train()
        max_loader_len = max([len(loader) for loader in self.train_loader])
        train_generators = [indefinite_generator(loader) for loader in self.train_loader]
        iteration = 0
        while iteration < max_loader_len * len(self.train_loader):
            for generator in train_generators:
                batch = next(generator)
                self.total_iter += 1
                num_seqs, num_frames = batch[0].shape[:2]
                total_im_num = num_seqs * num_frames
                if self.enable_memory_bank:
                    batch_features = self.forward_fix_embeddings(batch)
                    batch_embedding, embeddings, weights = self.retrieve_memory_bank(batch_features, batch)
                    bank_embedding_model_input = [batch_embedding, embeddings, weights]
                else:
                    # bank_embedding_model_input = None
                    batch_features = self.forward_fix_embeddings(batch)
                    weights = {
                        "weights": torch.rand(1, 10).to(batch_features.device),
                        "pick_idx": torch.randint(low=0, high=60, size=(1, 10)).to(batch_features.device)
                    }
                    bank_embedding_model_input = [batch_features[0], batch_features, weights]
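                    # the placeholder weights above only satisfy the model's
                    # bank-embedding interface; the shapes (1, 10) and the range 60
                    # mirror the default memory_bank_topk and memory_bank_size and
                    # carry no trained signal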
                m = self.model.forward(batch, epoch=epoch, iter=iteration, total_iter=self.total_iter, which_data=self.dataset, is_training=True, bank_embedding=bank_embedding_model_input)
                # self.model.backward()
                self.optimizerFewShot.zero_grad()
                self.model.total_loss.backward()
                self.optimizerFewShot.step()
                self.model.total_loss = 0.
                # if self.cfgs.get('texture_way', None) is not None and self.cfgs.get('gan_tex', False):
                if self.model.few_shot_gan_tex:
                    # discriminator step for the local texture
                    disc_ret = self.discriminator_texture_step()
                    m.update(disc_ret)
                if self.model.enable_disc and (self.model.mask_discriminator_iter[0] < self.total_iter) and (self.model.mask_discriminator_iter[1] > self.total_iter):
                    # mask discriminator training
                    discriminator_loss_dict, grad_loss = self.model.discriminator_step()
                    m.update(
                        {
                            'mask_disc_loss_discriminator': discriminator_loss_dict['discriminator_loss'] - grad_loss,
                            'mask_disc_loss_discriminator_grad': grad_loss,
                            'mask_disc_loss_discriminator_rv': discriminator_loss_dict['discriminator_loss_rv'],
                            'mask_disc_loss_discriminator_iv': discriminator_loss_dict['discriminator_loss_iv'],
                            'mask_disc_loss_discriminator_gt': discriminator_loss_dict['discriminator_loss_gt']
                        }
                    )
                    self.logger.add_histogram('train_' + 'discriminator_logits/random_view', discriminator_loss_dict['d_rv'], self.total_iter)
                    if discriminator_loss_dict['d_iv'] is not None:
                        self.logger.add_histogram('train_' + 'discriminator_logits/input_view', discriminator_loss_dict['d_iv'], self.total_iter)
                    if discriminator_loss_dict['d_gt'] is not None:
                        self.logger.add_histogram('train_' + 'discriminator_logits/gt_view', discriminator_loss_dict['d_gt'], self.total_iter)
                metrics.update(m, total_im_num)
                if self.rank == 0:
                    print(f"T{epoch:04}/{iteration:05}/{metrics}")
                if self.iteration_save and self.total_iter % self.iteration_save_freq == 0:
                    self.save_checkpoint(epoch + 1, total_iter=self.total_iter, optim=True, use_iter=True)
                # ## reset optimizers
                # if self.cfgs.get('opt_reset_every_iter', 0) > 0 and self.total_iter < self.cfgs.get('opt_reset_end_iter', 0):
                #     if self.total_iter % self.cfgs.get('opt_reset_every_iter', 0) == 0:
                #         self.model.reset_optimizers()
                if misc.is_main_process() and self.use_logger:
                    if self.rank == 0 and self.total_iter % self.log_freq_losses == 0:
                        for name, loss in m.items():
                            label = f'cub_loss_train/{name[4:]}' if 'cub' in name else f'loss_train/{name}'
                            self.logger.add_scalar(label, loss, self.total_iter)
                    if self.rank == 0 and self.save_result_freq is not None and self.total_iter % self.save_result_freq == 0:
                        with torch.no_grad():
                            m = self.model.forward(batch, epoch=epoch, iter=iteration, total_iter=self.total_iter, save_results=False, save_dir=self.train_result_dir, which_data=self.dataset, is_training=False, bank_embedding=bank_embedding_model_input)
                            torch.cuda.empty_cache()
                    if self.total_iter % self.log_freq_images == 0:
                        with torch.no_grad():
                            if self.rank == 0 and self.log_train_images:
                                m = self.model.forward(batch, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data=self.dataset, logger_prefix='train_', is_training=False, bank_embedding=bank_embedding_model_input)
                            if self.fix_viz_batch:
                                print(f'fix_viz_batch:{self.fix_viz_batch}')
                                batch_val = self.viz_batch
                            else:
                                batch_val = next(self.viz_data_iterator)
                            if self.visualize_validation:
                                import time
                                vis_start = time.time()
                                if self.enable_memory_bank:
                                    batch_features_val = self.forward_fix_embeddings(batch_val)
                                    batch_embedding_val, embeddings_val, weights_val = self.retrieve_memory_bank(batch_features_val, batch_val)
                                    bank_embedding_model_input_val = [batch_embedding_val, embeddings_val, weights_val]
                                else:
                                    # bank_embedding_model_input_val = None
                                    batch_features_val = self.forward_fix_embeddings(batch_val)
                                    weights_val = {
                                        "weights": torch.rand(1, 10).to(batch_features_val.device),
                                        "pick_idx": torch.randint(low=0, high=60, size=(1, 10)).to(batch_features_val.device)
                                    }
                                    bank_embedding_model_input_val = [batch_features_val[0], batch_features_val, weights_val]
                                # guard against save_result_freq being None (its default)
                                if self.save_result_freq is not None and self.total_iter % self.save_result_freq == 0:
                                    m = self.model.forward(batch_val, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, save_results=False, save_dir=self.train_result_dir, which_data=self.dataset, logger_prefix='val_', is_training=False, bank_embedding=bank_embedding_model_input_val)
                                    torch.cuda.empty_cache()
                                vis_end = time.time()
                                print(f"vis time: {vis_end - vis_start}")
                                if self.test_loader is not None:
                                    # visualization on unseen test categories
                                    batch_test = next(self.viz_test_data_iterator)
                                    if self.enable_memory_bank:
                                        batch_features_test = self.forward_fix_embeddings(batch_test)
                                        batch_embedding_test, embeddings_test, weights_test = self.retrieve_memory_bank(batch_features_test, batch_test)
                                        bank_embedding_model_input_test = [batch_embedding_test, embeddings_test, weights_test]
                                    else:
                                        # bank_embedding_model_input_test = None
                                        batch_features_test = self.forward_fix_embeddings(batch_test)
                                        weights_test = {
                                            "weights": torch.rand(1, 10).to(batch_features_test.device),
                                            "pick_idx": torch.randint(low=0, high=60, size=(1, 10)).to(batch_features_test.device)
                                        }
                                        bank_embedding_model_input_test = [batch_features_test[0], batch_features_test, weights_test]
                                    m_test = self.model.forward(batch_test, epoch=epoch, iter=iteration, viz_logger=self.logger, total_iter=self.total_iter, which_data=self.dataset, logger_prefix='test_', is_training=False, bank_embedding=bank_embedding_model_input_test)
                                    vis_test_end = time.time()
                                    print(f"vis test time: {vis_test_end - vis_end}")
                                    for name, loss in m_test.items():
                                        if self.rank == 0:
                                            self.logger.add_scalar(f'loss_test/{name}', loss, self.total_iter)
                                for name, loss in m.items():
                                    if self.rank == 0:
                                        self.logger.add_scalar(f'loss_val/{name}', loss, self.total_iter)
                            torch.cuda.empty_cache()
                iteration += 1
        self.model.scheduler_step()
        return metrics