vshirasuna's picture
Move code to 3dgrid_vqgan folder
a4c759f
#!/bin/bash -l
# Set up the environment and rank layout for multi-process VQ-GAN training
# (train_vqgan_DDP.py), which is then launched with mpirun.
#
# The shebang's `-l` makes bash start as a login shell so the user's login
# profile is read; that is usually what defines the `conda` shell function
# needed by `conda activate`. Running this file as `bash script.sh` bypasses
# the shebang (and thus `-l`), which is why activation is checked below.
#
# Optional environment overrides:
#   NNODES           number of nodes to launch on     (default: 1)
#   NRANKS_PER_NODE  MPI ranks started on each node    (default: 4)

# Load environment. Abort if activation fails: otherwise training would run
# silently with whatever python happens to be first on PATH.
if ! conda activate 3dvqgan-env; then
  echo "ERROR: failed to activate conda environment '3dvqgan-env'" >&2
  exit 1
fi
export LOGLEVEL=INFO
# Put the launch directory (presumably the repo root) on the import path.
export PYTHONPATH=$PWD
# Hydra: print the full stack trace instead of a truncated error summary.
export HYDRA_FULL_ERROR=1

# Rank layout: NRANKS_PER_NODE MPI ranks on each of NNODES nodes. No thread
# count or CPU binding is configured here.
NNODES=${NNODES:-1}
NRANKS_PER_NODE=${NRANKS_PER_NODE:-4}
# Reject non-numeric / zero overrides before they reach $(( )) and mpirun.
if [[ ! "$NNODES" =~ ^[1-9][0-9]*$ || ! "$NRANKS_PER_NODE" =~ ^[1-9][0-9]*$ ]]; then
  echo "ERROR: NNODES and NRANKS_PER_NODE must be positive integers" >&2
  exit 1
fi
NTOTRANKS=$(( NNODES * NRANKS_PER_NODE ))
printf 'NUM_OF_NODES= %s TOTAL_NUM_RANKS= %s RANKS_PER_NODE= %s\n' \
  "${NNODES}" "${NTOTRANKS}" "${NRANKS_PER_NODE}"
# Launch train_vqgan_DDP.py as NTOTRANKS MPI processes, NRANKS_PER_NODE per
# node. Launcher flags (Open MPI mpirun syntax):
#   -np N             total number of processes to start
#   -npernode N       processes per node (newer Open MPI releases prefer
#                     `--map-by ppr:N:node`)
#   -x PATH           forward PATH (which includes the activated conda env's
#                     bin directory) to the launched processes
#   --oversubscribe   allow more ranks than the node's detected slot count
# NOTE(review): only PATH is forwarded explicitly with -x. PYTHONPATH,
# LOGLEVEL and HYDRA_FULL_ERROR exported above likely reach ranks on the
# local node through normal environment inheritance, but with NNODES > 1
# they probably need their own `-x` flags. Confirm against the MPI in use.
# The key=value arguments after the script name are presumably Hydra
# overrides: `dataset=` / `model=` appear to select config groups, and the
# dotted keys override individual fields of those configs.
# Comments cannot be placed inside this backslash-continued command.
mpirun -np ${NTOTRANKS} \
-npernode ${NRANKS_PER_NODE} \
-x PATH \
--oversubscribe \
python train_vqgan_DDP.py \
dataset=default \
dataset.root_dir='../data/3d_grids_sample' \
model=vq_gan_3d \
model.default_root_dir_postfix='data_fm_qm9' \
model.precision=16 \
model.embedding_dim=256 \
model.n_hiddens=16 \
model.downsample=[4,4,4] \
model.num_workers=32 \
model.gradient_clip_val=1.0 \
model.lr=3e-4 \
model.discriminator_iter_start=450 \
model.perceptual_weight=4 \
model.image_gan_weight=1 \
model.gan_feat_weight=4 \
model.batch_size=1 \
model.n_codes=16384 \
model.accumulate_grad_batches=1 \
model.internal_resolution=128 \
model.checkpoint_every=1000 \
model.save_checkpoint_path='./checkpoints' \
model.resume_from_checkpoint='' \
model.max_epochs=100 \