|
"""Train or evaluate a detectron2 / Mask2Former model that uses BBNet components.

Run as a script. Command-line arguments come from detectron2's
``default_argument_parser``. Dataset and checkout paths are hard-coded
(cluster-specific) below.
"""
import os

# Environment configuration is done *before* ``import torch``. The OpenMP
# runtime that torch loads can read OMP_NUM_THREADS when it is loaded, so
# assigning it after the import may not limit this process's thread pool.
# Child processes inherit the value either way.
os.environ['OMP_NUM_THREADS'] = '1'
# Set ahead of every detectron2 import. It is presumably read by dataset
# registration at import time (TODO confirm against the detectron2 version in use).
os.environ['DETECTRON2_DATASETS'] = '/ccn2/u/honglinc/datasets'

import argparse  # noqa: E402,F401
import sys  # noqa: E402
import warnings  # noqa: E402

import torch  # noqa: E402

# Silence all Python warnings for the whole run.
warnings.filterwarnings("ignore")
# Share tensors between worker processes via the file system rather than
# file descriptors.
torch.multiprocessing.set_sharing_strategy('file_system')

# Local checkouts (cluster-specific paths). They are appended to sys.path so
# the project-level imports below can be resolved.
MASK2FORMER_PATH = '/ccn2/u/honglinc/Mask2Former'
BBNET_PATH = '/home/honglinc/BBNet'
sys.path.append(os.path.join(BBNET_PATH, 'bbnet/models/VideoMAE-main/'))
sys.path.append(BBNET_PATH)
sys.path.append(MASK2FORMER_PATH)

# NOTE(review): neither name below is referenced in this file. They are
# presumably imported for their side effects (e.g. registering model
# components), so confirm before removing.
import modeling_pretrain as vmae_tranformers  # noqa: E402,F401
from evaluate_segmentation_readout_helper_v2 import CWMSegmentPredictorV2  # noqa: E402,F401

import detectron2.utils.comm as comm  # noqa: E402
from detectron2.evaluation import verify_results  # noqa: E402
# train_net is resolved through the sys.path entries added above.
from train_net import setup, Trainer, DetectionCheckpointer  # noqa: E402
from detectron2.engine import default_argument_parser, launch  # noqa: E402
|
|
|
def main(args):
    """Run training, or evaluation only when ``args.eval_only`` is set.

    Args:
        args: parsed command-line namespace from detectron2's
            ``default_argument_parser``. This function reads ``eval_only`` and
            ``resume``; ``setup`` may read more.

    Returns:
        In evaluation mode, the results mapping from ``Trainer.test``. It is
        extended with test-time-augmentation results when
        ``cfg.TEST.AUG.ENABLED`` is set. Otherwise, whatever ``trainer.train()``
        returns.
    """
    cfg = setup(args)

    # Training path: build the trainer, optionally resume, and train.
    if not args.eval_only:
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()

    # Evaluation-only path: build the model and load cfg.MODEL.WEIGHTS
    # (or resume from cfg.OUTPUT_DIR when args.resume is set).
    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)

    results = Trainer.test(cfg, model)
    if cfg.TEST.AUG.ENABLED:
        results.update(Trainer.test_with_TTA(cfg, model))

    # Result verification happens only on the main process.
    if comm.is_main_process():
        verify_results(cfg, results)
    return results
|
|
|
|
|
if __name__ == "__main__":
    # Parse detectron2's standard CLI and echo it for the run log.
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)

    # Multi-machine settings forwarded to detectron2's launcher.
    distributed_opts = dict(
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
    )
    # launch() runs main(args) on args.num_gpus worker(s) per machine.
    launch(main, args.num_gpus, args=(args,), **distributed_opts)