Spaces:
Running
on
L4
Running
on
L4
File size: 2,460 Bytes
0a3525d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
# Base configuration for training a model
# Output locations. Other sections of this file refer to these values through
# ${paths.*} interpolation, so the two keys must be nested under `paths`.
paths:
  # Root directory of a run. `project` is not defined in the visible part of
  # this file, so it has to be supplied by whatever config composes this one.
  run_dir: results/${project}
  # Checkpoints live in a sub-directory of the run directory.
  ckpt_dir: ${paths.run_dir}/checkpoints
# Hydra settings: `hydra.run.dir` is set to the run directory defined under
# `paths`, so it resolves to the same location as the other outputs.
hydra:
  run:
    dir: ${paths.run_dir}
# Lightning Trainer
trainer:
  _target_: lightning.pytorch.trainer.Trainer
  default_root_dir: ${paths.run_dir}

  # Hardware / distribution
  accelerator: gpu
  num_nodes: 1
  devices: auto
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy
  precision: bf16-mixed

  # Validation schedule: turn off the per-epoch check and validate on a fixed
  # interval of 5000 instead.
  check_val_every_n_epoch: null
  val_check_interval: 5000

  # Stop after 100k steps. Written without `_` digit separators so that
  # YAML 1.2 parsers also resolve it to an integer rather than a string.
  max_steps: 100000

  # Use torch.backends.cudnn.benchmark to speed up training
  benchmark: true
# Callbacks
# Each child mapping describes one callback object via its `_target_` class.
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${paths.ckpt_dir}
    filename: "step_{step:09d}"
    # When true, an extra copy of the most recent checkpoint is kept as
    # last.ckpt; that copy is disabled here.
    save_last: false
    save_top_k: 5  # keep 5 checkpoints
    monitor: step  # rank checkpoints by step
    mode: max  # highest step wins, so the 5 kept are the most recent ones
    every_n_epochs: null  # don't save checkpoints at epoch end
    every_n_train_steps: 5000  # save a checkpoint every 5000 training steps
    auto_insert_metric_name: false

  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 2  # maximum depth of layer nesting included in the summary

  learning_rate_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false

  # Project-specific callback; its implementation is not part of this file.
  # NOTE(review): norm_type 2 presumably selects the L2 norm - confirm.
  grad_norm_monitor:
    _target_: fish_speech.callbacks.GradNormMonitor
    norm_type: 2
    logging_interval: step
# Logger
# Only the TensorBoard logger is active. The wandb entry below is a
# commented-out template; uncomment it as a sibling of `tensorboard`.
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: "${paths.run_dir}/tensorboard/"
    name: null
    log_graph: false
    default_hp_metric: true
    prefix: ""

  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   # name: ""  # name of the run (normally generated by wandb)
  #   save_dir: "${paths.run_dir}"
  #   offline: false
  #   id: null  # pass correct id to resume experiment!
  #   anonymous: null  # enable anonymous logging
  #   project: "fish-speech"
  #   log_model: false  # upload lightning ckpts
  #   prefix: ""  # a string to put at the beginning of metric keys
  #   # entity: ""  # set to name of your wandb team
  #   group: ""
  #   tags: ["vq", "hq", "finetune"]
  #   job_type: ""
# Loop
# Top-level on/off switches: training is enabled, testing is disabled.
# NOTE(review): the code that reads these flags is not visible in this file;
# presumably the training entry script - confirm.
train: true
test: false
|