# Anymate/Train.py
import os
import shutil
import argparse
import multiprocessing

import torch
import torch.multiprocessing as mp
import yaml

from Anymate.utils.train_utils import train_model
from Anymate.dataset import AnymateDataset

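# Example invocation (a sketch; flags other than --config/--split are supplied
# through the YAML config's 'args' section, e.g. logdir, checkpoint, mode,
# resume, trainset, root):
#   python Train.py --config joints
#   python Train.py --config joints --split
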
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Anymate training')
    parser.add_argument('--config', type=str, default='joints',
                        help='name of the config file (without .yaml) under ./Anymate/configs')
    parser.add_argument('--split', action='store_true', help='use split dataset')
    args = parser.parse_args()

    world_size = torch.cuda.device_count()
    print('world_size', world_size)

    # Load the config file and merge its 'args' section into the CLI args
    config_folder = './Anymate/configs'
    config_path = os.path.join(config_folder, args.config + '.yaml')
    assert os.path.exists(config_path), f"Config file {config_path} not found"
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    for key, value in config['args'].items():
        setattr(args, key, value)
    setattr(args, 'decoder', config['model']['decoder'])

    # Name the run after the mode and the encoder/decoder pair from the config
    run_name = args.mode + '-' + config['model']['encoder'] + '-' + config['model']['decoder']
    args.logdir = os.path.join(args.logdir, run_name)
    args.checkpoint = os.path.join(args.checkpoint, run_name)
    print(args)
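
    # The config is assumed to have roughly this shape (a sketch inferred from
    # the keys read above; the real files live in ./Anymate/configs):
    #   args:
    #     mode: <mode>
    #     logdir: <log root>
    #     checkpoint: <checkpoint root>
    #     resume: false
    #     trainset: <dataset name>
    #     root: <dataset root>
    #   model:
    #     encoder: <encoder name>
    #     decoder: <decoder name>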

    # Create the checkpoint dir and the log dir (wipe old logs unless resuming)
    if not os.path.isdir(args.checkpoint):
        print("Create new checkpoint folder " + args.checkpoint)
        os.makedirs(args.checkpoint, exist_ok=True)
    if not args.resume and os.path.isdir(args.logdir):
        shutil.rmtree(args.logdir)
    os.makedirs(args.logdir, exist_ok=True)

    # Build the training dataset once in the main process, unless --split is set
    if not args.split:
        train_dataset = AnymateDataset(name=args.trainset, root=args.root)
        # Drop samples whose voxel grid is empty
        train_dataset.data_list = [data for data in train_dataset.data_list if data['vox'].shape[0] != 0]
        print('train_dataset', len(train_dataset.data_list))
        manager = multiprocessing.Manager()
        shared_dict = manager.dict()
        shared_dict['train_dataset'] = train_dataset
    else:
        shared_dict = None
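
    # Design note: the Manager-backed dict lets the dataset be read from disk
    # and filtered once in the parent; each spawned rank then fetches it from
    # the manager process (as a pickled copy) rather than re-loading the raw data.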

    # Try consecutive ports until the distributed group binds one successfully
    port = 12355
    while port < 65535:  # max port number
        try:
            # Spawn one training process per GPU; this blocks until training ends
            mp.spawn(train_model, args=(world_size, config, args, shared_dict, port), nprocs=world_size)
            print(f"Finished training on port {port}")
            break
        except Exception as e:
            if "address already in use" in str(e).lower():
                print(f"Port {port} is already in use, trying next port")
                port += 1
            else:
                print(f"Error starting training on port {port}: {e}")
                raise
    else:
        raise RuntimeError("No free port found for distributed training")