teticio commited on
Commit
fdb8eec
1 Parent(s): 249ec3b

fix when save_model_epochs != save_images_epochs

Browse files
Files changed (1) hide show
  1. scripts/train_unconditional.py +9 -7
scripts/train_unconditional.py CHANGED
@@ -176,11 +176,11 @@ def main(args):
176
 
177
  if args.push_to_hub:
178
  if args.hub_model_id is None:
179
- repo_name = get_full_repo_name(Path(args.output_dir).name,
180
  token=args.hub_token)
181
  else:
182
  repo_name = args.hub_model_id
183
- repo = Repository(args.output_dir, clone_from=repo_name)
184
 
185
  if accelerator.is_main_process:
186
  run = os.path.split(__file__)[-1].split(".")[0]
@@ -270,9 +270,9 @@ def main(args):
270
 
271
  # Generate sample images for visual inspection
272
  if accelerator.is_main_process:
273
- if (
274
  epoch + 1
275
- ) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
276
  pipeline = AudioDiffusionPipeline(
277
  vqvae=vqvae,
278
  unet=accelerator.unwrap_model(
@@ -280,15 +280,17 @@ def main(args):
280
  mel=mel,
281
  scheduler=noise_scheduler,
282
  )
283
- pipeline.save_pretrained(args.output_dir)
 
 
 
 
284
 
285
  # save the model
286
  if args.push_to_hub:
287
  repo.push_to_hub(commit_message=f"Epoch {epoch}",
288
  blocking=False,
289
  auto_lfs_prune=True)
290
- else:
291
- pipeline.save_pretrained(output_dir)
292
 
293
  if (epoch + 1) % args.save_images_epochs == 0:
294
  generator = torch.Generator(
 
176
 
177
  if args.push_to_hub:
178
  if args.hub_model_id is None:
179
+ repo_name = get_full_repo_name(Path(output_dir).name,
180
  token=args.hub_token)
181
  else:
182
  repo_name = args.hub_model_id
183
+ repo = Repository(output_dir, clone_from=repo_name)
184
 
185
  if accelerator.is_main_process:
186
  run = os.path.split(__file__)[-1].split(".")[0]
 
270
 
271
  # Generate sample images for visual inspection
272
  if accelerator.is_main_process:
273
+ if (epoch + 1) % args.save_model_epochs == 0 or (
274
  epoch + 1
275
+ ) % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
276
  pipeline = AudioDiffusionPipeline(
277
  vqvae=vqvae,
278
  unet=accelerator.unwrap_model(
 
280
  mel=mel,
281
  scheduler=noise_scheduler,
282
  )
283
+
284
+ if (
285
+ epoch + 1
286
+ ) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
287
+ pipeline.save_pretrained(output_dir)
288
 
289
  # save the model
290
  if args.push_to_hub:
291
  repo.push_to_hub(commit_message=f"Epoch {epoch}",
292
  blocking=False,
293
  auto_lfs_prune=True)
 
 
294
 
295
  if (epoch + 1) % args.save_images_epochs == 0:
296
  generator = torch.Generator(