# Model section: the engine class to build (`target`) and its settings
# (`params`).
# NOTE(review): every `target`/`params` pair in this file looks like
# "dotted import path + constructor keyword arguments" — confirm against the
# config loader.
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: true
    # `txt` is also the `input_key` of the first conditioner embedder.
    log_keys:
      - txt

    # Learning-rate schedule; every setting is a one-element list.
    # NOTE(review): presumably one list entry per schedule cycle, with the
    # factor ramping from f_start to f_max over warm_up_steps — confirm
    # against LambdaLinearScheduler.
    scheduler_config:
      target: sgm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [10000]
        cycle_lengths: [10000000000000]
        f_start: [1.0e-6]
        f_min: [1.0]
        f_max: [1.0]

    # Denoiser settings.
    # NOTE(review): `num_idx` and the discretization class are repeated under
    # loss_fn_config.sigma_sampler_config; presumably they must stay in
    # sync — confirm.
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling

    # Network settings (target class: UNetModel).
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: true
        # NOTE(review): 4 matches `z_channels: 4` under first_stage_config;
        # presumably the network operates on the autoencoder latents — confirm.
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [1, 2, 4]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        # NOTE(review): both num_head_channels and num_heads are set; confirm
        # which of the two UNetModel actually honors.
        num_head_channels: 64
        num_heads: 1
        num_classes: sequential
        # NOTE(review): 1792 = 768 + 4 * 256; presumably a pooled text
        # embedding of width 768 plus the two ConcatTimestepEmbedderND
        # embedders (outdim 256) applied to 2-element inputs — confirm.
        adm_in_channels: 1792
        transformer_depth: 1
        context_dim: 768
        spatial_transformer_attn_type: softmax-xformers

    # Conditioning embedders, listed in order under `emb_models`. Every entry
    # uses the same key order: is_trainable, input_key, ucg_rate, target,
    # params.
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # NOTE(review): is_trainable is true although the class is named
          # FrozenCLIPEmbedder — confirm this is intended.
          - is_trainable: true
            input_key: txt
            ucg_rate: 0.1
            legacy_ucg_value: ""
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              always_return_pooled: true

          - is_trainable: false
            input_key: original_size_as_tuple
            ucg_rate: 0.1
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: false
            input_key: crop_coords_top_left
            ucg_rate: 0.1
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

    # First-stage model settings (target class: AutoencoderKL).
    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        # USER: CKPT_PATH looks like a placeholder (compare DATA_PATH in the
        # data section); presumably it must be replaced with the path to the
        # autoencoder checkpoint — confirm.
        ckpt_path: CKPT_PATH
        embed_dim: 4
        monitor: val/rec_loss
        lossconfig:
          target: torch.nn.Identity
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          resolution: 256
          in_channels: 3
          out_ch: 3
          z_channels: 4
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0

    # Loss settings: a sigma sampler and a loss weighting.
    # NOTE(review): `num_idx` and the discretization class repeat the values
    # under denoiser_config; presumably they must stay in sync — confirm.
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000
            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting

    # Sampler settings: step count, guider and discretization.
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 7.5
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

# Data section: one training pipeline (URLs, shuffle settings, decoders,
# postprocessors) and its loader settings.
data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # USER: adapt this path to the root of your custom dataset
          - DATA_PATH
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000

        decoders:
          - pil

        # Postprocessors are listed in order; keep the order when editing.
        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              # USER: you may want to adapt this key for your custom dataset
              key: jpg
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    # NOTE(review): integer interpolation code; presumably 3
                    # selects bicubic — confirm against torchvision.
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          # USER: you may want to pass non-default parameters for your custom
          # dataset
          - target: sdata.mappers.Rescaler
          # USER: you may want to pass non-default parameters for your custom
          # dataset
          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare

      loader:
        batch_size: 64
        num_workers: 6

# Lightning section: checkpointing, callbacks and trainer settings.
lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        enable_autocast: false
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: true
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          # Key is quoted because a bare N is boolean-like under strict
          # YAML 1.1 rules; the key stays the string "N".
          "N": 8
          n_rows: 2

  trainer:
    # Quoted so the trailing comma is clearly part of the value: this is the
    # string "0,", not the integer 0.
    # NOTE(review): presumably the trainer's device-list syntax selecting
    # device index 0 — confirm.
    devices: "0,"
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 1000