# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. """Wrapper to train and test a video classification model.""" from timesformer.utils.misc import launch_job from timesformer.utils.parser import load_config, parse_args from tools.test_net import test from tools.train_net import train def get_func(cfg): train_func = train test_func = test return train_func, test_func def main(): """ Main function to spawn the train and test process. """ args = parse_args() if args.num_shards > 1: args.output_dir = str(args.job_dir) cfg = load_config(args) train, test = get_func(cfg) # Perform training. if cfg.TRAIN.ENABLE: launch_job(cfg=cfg, init_method=args.init_method, func=train) # Perform multi-clip testing. if cfg.TEST.ENABLE: launch_job(cfg=cfg, init_method=args.init_method, func=test) # Perform model visualization. if cfg.TENSORBOARD.ENABLE and ( cfg.TENSORBOARD.MODEL_VIS.ENABLE or cfg.TENSORBOARD.WRONG_PRED_VIS.ENABLE ): launch_job(cfg=cfg, init_method=args.init_method, func=visualize) if __name__ == "__main__": main()