#!/bin/bash -l
#
# Launch train_vqgan_DDP.py (3D VQ-GAN) under mpirun.
#
# Run from the directory that contains train_vqgan_DDP.py: it is invoked by a
# relative path, and $PWD is put on PYTHONPATH.
#
# Optional environment overrides:
#   NNODES           number of nodes                 (default: 1)
#   NRANKS_PER_NODE  MPI ranks started on each node  (default: 4)
#
# Requires `conda` (with an env named 3dvqgan-env) and `mpirun` on PATH.

set -eo pipefail

die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

# `conda activate` is a shell function installed by conda's shell hook, not a
# standalone executable. The `-l` shebang flag only gets it loaded (via the
# login profile) when this file is executed directly -- `bash script.sh`
# ignores the shebang -- so load the hook explicitly and fail loudly instead
# of silently running with the wrong Python.
command -v conda >/dev/null 2>&1 || die "conda not found on PATH"
conda_hook=$(conda shell.bash hook) || die "'conda shell.bash hook' failed"
eval "$conda_hook"
conda activate 3dvqgan-env || die "could not activate conda env '3dvqgan-env'"

# Enable unset-variable errors only after activation: conda's activation
# scripts may reference unset variables.
set -u

export LOGLEVEL=INFO
# Prepend instead of overwrite so inherited PYTHONPATH entries stay importable.
export PYTHONPATH="${PWD}${PYTHONPATH:+:${PYTHONPATH}}"
# Make Hydra print the complete stack trace when the training script fails.
export HYDRA_FULL_ERROR=1

# Launch geometry; override from the environment, e.g. `NNODES=2 <this script>`.
NNODES=${NNODES:-1}
NRANKS_PER_NODE=${NRANKS_PER_NODE:-4}
for name in NNODES NRANKS_PER_NODE; do
  [[ ${!name} =~ ^[1-9][0-9]*$ ]] \
    || die "${name} must be a positive integer, got '${!name}'"
done

NTOTRANKS=$(( NNODES * NRANKS_PER_NODE ))

printf 'NUM_OF_NODES= %s TOTAL_NUM_RANKS= %s RANKS_PER_NODE= %s\n' \
  "${NNODES}" "${NTOTRANKS}" "${NRANKS_PER_NODE}"
|
|
|
mpirun -np ${NTOTRANKS} \ |
|
-npernode ${NRANKS_PER_NODE} \ |
|
-x PATH \ |
|
--oversubscribe \ |
|
python train_vqgan_DDP.py \ |
|
dataset=default \ |
|
dataset.root_dir='../data/3d_grids_sample' \ |
|
model=vq_gan_3d \ |
|
model.default_root_dir_postfix='data_fm_qm9' \ |
|
model.precision=16 \ |
|
model.embedding_dim=256 \ |
|
model.n_hiddens=16 \ |
|
model.downsample=[4,4,4] \ |
|
model.num_workers=32 \ |
|
model.gradient_clip_val=1.0 \ |
|
model.lr=3e-4 \ |
|
model.discriminator_iter_start=450 \ |
|
model.perceptual_weight=4 \ |
|
model.image_gan_weight=1 \ |
|
model.gan_feat_weight=4 \ |
|
model.batch_size=1 \ |
|
model.n_codes=16384 \ |
|
model.accumulate_grad_batches=1 \ |
|
model.internal_resolution=128 \ |
|
model.checkpoint_every=1000 \ |
|
model.save_checkpoint_path='./checkpoints' \ |
|
model.resume_from_checkpoint='' \ |
|
model.max_epochs=100 \ |