Vision-CAIR commited on
Commit
c55fc6f
·
1 Parent(s): a91989d

Delete train.py

Browse files
Files changed (1) hide show
  1. train.py +0 -103
train.py DELETED
@@ -1,103 +0,0 @@
1
- """
2
- Copyright (c) 2022, salesforce.com, inc.
3
- All rights reserved.
4
- SPDX-License-Identifier: BSD-3-Clause
5
- For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- """
7
-
8
- import argparse
9
- import os
10
- import random
11
-
12
- import numpy as np
13
- import torch
14
- import torch.backends.cudnn as cudnn
15
-
16
- import minigpt4.tasks as tasks
17
- from minigpt4.common.config import Config
18
- from minigpt4.common.dist_utils import get_rank, init_distributed_mode
19
- from minigpt4.common.logger import setup_logger
20
- from minigpt4.common.optims import (
21
- LinearWarmupCosineLRScheduler,
22
- LinearWarmupStepLRScheduler,
23
- )
24
- from minigpt4.common.registry import registry
25
- from minigpt4.common.utils import now
26
-
27
- # imports modules for registration
28
- from minigpt4.datasets.builders import *
29
- from minigpt4.models import *
30
- from minigpt4.processors import *
31
- from minigpt4.runners import *
32
- from minigpt4.tasks import *
33
-
34
-
35
- def parse_args():
36
- parser = argparse.ArgumentParser(description="Training")
37
-
38
- parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
39
- parser.add_argument(
40
- "--options",
41
- nargs="+",
42
- help="override some settings in the used config, the key-value pair "
43
- "in xxx=yyy format will be merged into config file (deprecate), "
44
- "change to --cfg-options instead.",
45
- )
46
-
47
- args = parser.parse_args()
48
- # if 'LOCAL_RANK' not in os.environ:
49
- # os.environ['LOCAL_RANK'] = str(args.local_rank)
50
-
51
- return args
52
-
53
-
54
- def setup_seeds(config):
55
- seed = config.run_cfg.seed + get_rank()
56
-
57
- random.seed(seed)
58
- np.random.seed(seed)
59
- torch.manual_seed(seed)
60
-
61
- cudnn.benchmark = False
62
- cudnn.deterministic = True
63
-
64
-
65
- def get_runner_class(cfg):
66
- """
67
- Get runner class from config. Default to epoch-based runner.
68
- """
69
- runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))
70
-
71
- return runner_cls
72
-
73
-
74
- def main():
75
- # allow auto-dl completes on main process without timeout when using NCCL backend.
76
- # os.environ["NCCL_BLOCKING_WAIT"] = "1"
77
-
78
- # set before init_distributed_mode() to ensure the same job_id shared across all ranks.
79
- job_id = now()
80
-
81
- cfg = Config(parse_args())
82
-
83
- init_distributed_mode(cfg.run_cfg)
84
-
85
- setup_seeds(cfg)
86
-
87
- # set after init_distributed_mode() to only log on master.
88
- setup_logger()
89
-
90
- cfg.pretty_print()
91
-
92
- task = tasks.setup_task(cfg)
93
- datasets = task.build_datasets(cfg)
94
- model = task.build_model(cfg)
95
-
96
- runner = get_runner_class(cfg)(
97
- cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
98
- )
99
- runner.train()
100
-
101
-
102
- if __name__ == "__main__":
103
- main()