Commit c55fc6f (1 parent: a91989d)
Delete train.py
train.py (DELETED, 103 lines)
@@ -1,103 +0,0 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn

import minigpt4.tasks as tasks
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank, init_distributed_mode
from minigpt4.common.logger import setup_logger
from minigpt4.common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
)
from minigpt4.common.registry import registry
from minigpt4.common.utils import now

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *


def parse_args():
    parser = argparse.ArgumentParser(description="Training")

    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )

    args = parser.parse_args()
    # if 'LOCAL_RANK' not in os.environ:
    #     os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args


def setup_seeds(config):
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True


def get_runner_class(cfg):
    """
    Get runner class from config. Default to epoch-based runner.
    """
    runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))

    return runner_cls


def main():
    # allow auto-dl completes on main process without timeout when using NCCL backend.
    # os.environ["NCCL_BLOCKING_WAIT"] = "1"

    # set before init_distributed_mode() to ensure the same job_id shared across all ranks.
    job_id = now()

    cfg = Config(parse_args())

    init_distributed_mode(cfg.run_cfg)

    setup_seeds(cfg)

    # set after init_distributed_mode() to only log on master.
    setup_logger()

    cfg.pretty_print()

    task = tasks.setup_task(cfg)
    datasets = task.build_datasets(cfg)
    model = task.build_model(cfg)

    runner = get_runner_class(cfg)(
        cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
    )
    runner.train()


if __name__ == "__main__":
    main()
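For context, the deleted script is a standard config-driven launcher: parse_args() requires --cfg-path and accepts --options key=value overrides, and init_distributed_mode() reads the rank environment that a distributed launcher provides. Below is a minimal, hedged sketch of how such an entry point is typically invoked; the torchrun launcher, the config path, and the override key are illustrative assumptions, not values recorded in this commit.

# Hypothetical launch sketch for the deleted entry point (not part of this commit).
# The config path and the --options override are illustrative assumptions; torchrun
# is assumed because init_distributed_mode() expects the RANK / WORLD_SIZE /
# LOCAL_RANK environment variables that such a launcher sets.
import subprocess

subprocess.run(
    [
        "torchrun", "--nproc_per_node=1", "train.py",
        "--cfg-path", "train_configs/example_config.yaml",  # hypothetical config file
        "--options", "run.seed=42",                          # hypothetical override
    ],
    check=True,
)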