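"""Entry point: dispatches X-Decoder training and evaluation runs."""
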
import logging
import os
import random
import sys

import numpy as np
import torch

from utilities.arguments import load_opt_command

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
def set_seed(seed: int = 42) -> None:
    """Seed every RNG source used in training so runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seed all visible GPUs, not just the current device

    # Prefer deterministic cuDNN kernels and disable the autotuner,
    # trading some speed for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Only affects child processes; the current interpreter's hash seed
    # is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    logger.info("Random seed set as %d", seed)
|
|
def main(args=None):
    """Main function for the entry point.

    1. Set environment variables for distributed training.
    2. Load the config file and set up the trainer.
    """
    opt, cmdline_args = load_opt_command(args)
    command = cmdline_args.command

    if cmdline_args.user_dir:
        # Resolve the user directory to an absolute path before storing
        # it in the config.
        absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
        opt['base_path'] = absolute_user_dir
|
    # Under an MPI launcher such as mpirun, OMPI_COMM_WORLD_SIZE holds the
    # total number of processes in the job.
    world_size = 1
    if 'OMPI_COMM_WORLD_SIZE' in os.environ:
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])

    if opt['TRAINER'] == 'xdecoder':
        from trainer import XDecoder_Trainer as Trainer
    else:
        # A failed assert can be silently stripped under `python -O`,
        # so raise explicitly instead.
        raise ValueError(f"The trainer type: {opt['TRAINER']} is not defined!")

    set_seed(opt['RANDOM_SEED'])

    trainer = Trainer(opt)
    # Ask torch.distributed for verbose diagnostics (e.g. unused-parameter
    # reports) when debugging collective failures.
    os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'

    if command == "train":
        trainer.train()
    elif command == "evaluate":
        trainer.eval()
    else:
        raise ValueError(f"Unknown command: {command}")
|
|
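# Typical launch, assuming this file is the repo's entry script; the exact
# flags come from utilities.arguments.load_opt_command, so the ones below
# are illustrative only:
#   mpirun -n 8 python entry.py train --conf_files <your_config>.yaml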
if __name__ == "__main__":
    main()
    sys.exit(0)
|
|