# Weakly-Supervised-3DOD / tools/train_score.py
# Uploaded by AndreasLH ("upload repo", commit 56bd2b5)
# Copyright (c) Meta Platforms, Inc. and affiliates
import sys
import warnings
warnings.filterwarnings("ignore", message="Overwriting tiny_vit_21m_512 in registry")
warnings.filterwarnings("ignore", message="Overwriting tiny_vit_21m_384 in registry")
warnings.filterwarnings("ignore", message="Overwriting tiny_vit_21m_224 in registry")
warnings.filterwarnings("ignore", message="Overwriting tiny_vit_11m_224 in registry")
warnings.filterwarnings("ignore", message="Overwriting tiny_vit_5m_224 in registry")
from cubercnn.data.generate_ground_segmentations import init_segmentation
import logging
import os
import torch
import datetime
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import (
default_argument_parser,
default_setup,
)
from detectron2.utils.logger import setup_logger
from cubercnn.data.dataset_mapper import DatasetMapper3D
logger = logging.getLogger("scoring")
from cubercnn.config import get_cfg_defaults
from cubercnn.modeling.meta_arch import build_model, build_model_scorenet
from cubercnn import util, vis, data
# even though this import is unused, it initializes the backbone registry
from cubercnn.modeling.backbone import build_dla_from_vision_fpn_backbone
# The imports below were added for do_train
from detectron2.engine import (
default_argument_parser,
default_setup,
default_writers,
launch
)
from detectron2.solver import build_lr_scheduler
from detectron2.utils.events import EventStorage
import wandb
from cubercnn.solver import build_optimizer, freeze_bn, PeriodicCheckpointerOnlyOne
from cubercnn.data import (
load_omni3d_json,
DatasetMapper3D,
build_detection_train_loader,
build_detection_test_loader,
get_omni3d_categories,
simple_register
)
from tqdm import tqdm
def do_train(cfg, model, dataset_id_to_unknown_cats, dataset_id_to_src, resume=False):
    """
    Train the score network on top of a base network kept in eval mode.

    Args:
        cfg: config node. Reads SOLVER.MAX_ITER, SOLVER.CHECKPOINT_PERIOD,
            OUTPUT_DIR, MODEL.WEIGHTS, MODEL.WEIGHTS_PRETRAIN, MODEL.USE_BN
            and MODEL.DEVICE.
        model: indexable pair ``(modelbase, scorenet)``. ``model[0]`` is put in
            eval mode and only produces features; ``model[1]`` is the network
            that is optimized and checkpointed.
        dataset_id_to_unknown_cats: mapping attached to the dataset mapper.
        dataset_id_to_src: mapping forwarded to the train data loader.
        resume: forwarded to ``checkpointer.resume_or_load``.

    Returns:
        True once the loop has covered iterations ``start_iter .. max_iter - 1``.
        If the resumed iteration is already >= max_iter, no training step is run.
    """
    max_iter = cfg.SOLVER.MAX_ITER

    modelbase, scorenet = model[0], model[1]
    modelbase.eval()
    scorenet.train()

    optimizer = build_optimizer(cfg, scorenet)
    scheduler = build_lr_scheduler(cfg, optimizer)

    # bookkeeping
    checkpointer = DetectionCheckpointer(scorenet, save_dir=cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler)
    periodic_checkpointer = PeriodicCheckpointerOnlyOne(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter)
    writers = default_writers(cfg.OUTPUT_DIR, max_iter)

    # create the dataloader
    data_mapper = DatasetMapper3D(cfg, is_train=False, mode='load_proposals')
    data_loader = build_detection_train_loader(cfg, mapper=data_mapper, dataset_id_to_src=dataset_id_to_src, num_workers=4)

    # give the mapper access to dataset_ids
    data_mapper.dataset_id_to_unknown_cats = dataset_id_to_unknown_cats

    if cfg.MODEL.WEIGHTS_PRETRAIN != '':
        # load ONLY the model, no checkpointables.
        checkpointer.load(cfg.MODEL.WEIGHTS_PRETRAIN, checkpointables=[])

    # determine the starting iteration, if resuming
    start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
    logger.info("Starting training from iteration {}".format(start_iter))

    if not cfg.MODEL.USE_BN:
        freeze_bn(modelbase)

    data_iter = iter(data_loader)

    # NOTE(review): the returned segmentor is not used anywhere in this function;
    # the call is kept in case init_segmentation has side effects -- confirm and remove.
    segmentor = init_segmentation(device=cfg.MODEL.DEVICE)

    try:
        with EventStorage(start_iter) as storage, \
                tqdm(total=max_iter, initial=start_iter, desc="Training", smoothing=0.05) as pbar:
            # Bounding the loop by the range (instead of `while True` + break)
            # means a run resumed at or past max_iter performs no extra step.
            for iteration in range(start_iter, max_iter):
                data = next(data_iter)
                storage.iter = iteration

                # forward
                # NOTE(review): modelbase is in eval mode but its forward still runs
                # with autograd enabled; torch.no_grad() may save memory here if no
                # gradient has to flow through it -- confirm before changing.
                combined_features = modelbase(data)
                instances, loss_1, loss_2 = scorenet(data, combined_features)

                # Average both losses over the batch; additionally halve the IoU
                # loss and divide the segment loss by 100 before summing them.
                loss_1 /= 2
                loss_1 /= len(data)
                loss_2 /= len(data)
                loss_2 /= 100
                total_loss = loss_1 + loss_2

                # send loss scalars to tensorboard.
                storage.put_scalars(total_loss=total_loss, IoU_loss=loss_1, segment_loss=loss_2)

                # backward and step
                total_loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

                storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
                periodic_checkpointer.step(iteration)

                # logging stuff
                pbar.update(1)
                pbar.set_postfix({"tot.loss": total_loss.item(), "IoU.loss": loss_1.item(), "Seg.loss": loss_2.item()})
                if iteration - start_iter > 5 and ((iteration + 1) % 20 == 0 or iteration == max_iter - 1):
                    # 3 writers; 1: prints, 2: json logs, 3: tensorboard.
                    # The printer is skipped because the progress bar already shows the losses.
                    for writer in writers[1:]:
                        writer.write()
    finally:
        # release the file handles held by the event writers
        for writer in writers:
            writer.close()

    # success
    return True
def setup(args):
    """
    Build the frozen config for this run and register the datasets it names.

    Merges the config file and command-line overrides from ``args``, selects
    the device based on CUDA availability, fixes the seed to 13, runs the
    default detectron2 setup, configures the "scoring" logger and registers
    every train dataset plus every test dataset not already in the train list.

    Returns:
        The frozen config node.
    """
    cfg = get_cfg()
    get_cfg_defaults(cfg)

    # store locally if needed
    handler = util.CubeRCNNHandler
    config_path = args.config_file
    if config_path.startswith(handler.PREFIX):
        config_path = handler._get_local_path(handler, config_path)

    cfg.merge_from_file(config_path)
    cfg.merge_from_list(args.opts)

    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    cfg.SEED = 13
    cfg.freeze()

    default_setup(cfg, args)
    setup_logger(output=cfg.OUTPUT_DIR, name="scoring")

    filter_settings = data.get_filter_settings_from_cfg(cfg)

    for train_name in cfg.DATASETS.TRAIN:
        simple_register(train_name, filter_settings, filter_empty=True)

    for test_name in cfg.DATASETS.TEST:
        # datasets shared with the train list were registered above
        if test_name in cfg.DATASETS.TRAIN:
            continue
        # TODO: empties should not be filtering in test normally, or maybe they should??
        simple_register(test_name, filter_settings, filter_empty=True)

    return cfg
def main(args):
    """
    Entry point: build the config, models and dataset bookkeeping, then train.

    Steps:
      1. Create the config via ``setup(args)``.
      2. Load the category metadata JSON and register ``thing_classes`` and
         ``thing_dataset_id_to_contiguous_id`` under 'omni3d_model'.
      3. Build the 'ScoreNetBase' and 'ScoreNet' models.
      4. Load the Omni3D train datasets and store the model metadata.
      5. For each dataset info entry, record its source and the set of
         category ids (contiguous) that are NOT known for that dataset.
      6. Hand everything to ``do_train``.

    Args:
        args: parsed command-line arguments; ``args.resume`` is forwarded to
            ``do_train`` and the rest is consumed by ``setup``.

    Returns:
        Whatever ``do_train`` returns.
    """
    cfg = setup(args)

    # NOTE: wandb logging is currently disabled in this script.

    category_path = 'output/Baseline_sgd/category_meta.json'

    # store locally if needed
    if category_path.startswith(util.CubeRCNNHandler.PREFIX):
        category_path = util.CubeRCNNHandler._get_local_path(util.CubeRCNNHandler, category_path)

    metadata = util.load_json(category_path)

    # register the categories
    thing_classes = metadata['thing_classes']
    id_map = {int(key): val for key, val in metadata['thing_dataset_id_to_contiguous_id'].items()}
    MetadataCatalog.get('omni3d_model').thing_classes = thing_classes
    MetadataCatalog.get('omni3d_model').thing_dataset_id_to_contiguous_id = id_map

    # build the model.
    modelbase = build_model_scorenet(cfg, 'ScoreNetBase')
    model = build_model_scorenet(cfg, 'ScoreNet')

    filter_settings = data.get_filter_settings_from_cfg(cfg)

    # setup and join the data.
    dataset_paths = [os.path.join('datasets', 'Omni3D', name + '.json') for name in cfg.DATASETS.TRAIN]
    datasets = data.Omni3D(dataset_paths, filter_settings=filter_settings)

    # determine the meta data given the datasets used.
    data.register_and_store_model_metadata(datasets, cfg.OUTPUT_DIR, filter_settings)
    dataset_id_to_contiguous_id = MetadataCatalog.get('omni3d_model').thing_dataset_id_to_contiguous_id

    # It may be useful to keep track of which categories are annotated/known
    # for each dataset in use, in case a method wants to use this information.
    infos = datasets.dataset['info']
    if isinstance(infos, dict):
        # a single dataset stores one info dict rather than a list of them
        infos = [infos]

    dataset_id_to_unknown_cats = {}
    dataset_id_to_src = {}
    possible_categories = set(range(cfg.MODEL.ROI_HEADS.NUM_CLASSES + 1))

    for info in infos:
        dataset_id = info['id']

        if dataset_id not in dataset_id_to_src:
            dataset_id_to_src[dataset_id] = info['source']

        # map the dataset's known category ids into contiguous training ids,
        # ignoring categories that are not part of the model's id map
        known_category_training_ids = {
            dataset_id_to_contiguous_id[category_id]
            for category_id in info['known_category_ids']
            if category_id in dataset_id_to_contiguous_id
        }

        # determine and store the unknown categories.
        dataset_id_to_unknown_cats[dataset_id] = possible_categories - known_category_training_ids

    return do_train(cfg, (modelbase, model), dataset_id_to_unknown_cats, dataset_id_to_src, resume=args.resume)
if __name__ == "__main__":
    # Script entry: parse the standard detectron2 command-line arguments,
    # echo them for the run log, and start training.
    parser = default_argument_parser()
    args = parser.parse_args()
    print("Command Line Args:", args)
    main(args)