File size: 28,820 Bytes
7b0a1ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
import os, torch
import os.path as osp
import shutil
import numpy as np
import copy
import torch.nn as nn
from tqdm.auto import tqdm
from accelerate import Accelerator
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader, DistributedSampler
from semanticist.utils.logger import SmoothedValue, MetricLogger, empty_cache
from accelerate.utils import DistributedDataParallelKwargs
from semanticist.stage2.gpt import GPT_models
from semanticist.stage2.generate import generate
from pathlib import Path
import time

from semanticist.engine.trainer_utils import (
    instantiate_from_config, concat_all_gather,
    save_img_batch, get_fid_stats,
    EMAModel, create_scheduler, load_state_dict, load_safetensors,
    setup_result_folders, create_optimizer,
    CacheDataLoader
)

class GPTTrainer(nn.Module):
    def __init__(
        self,
        ae_model,
        gpt_model,
        dataset,
        test_only=False,
        num_test_images=50000,
        num_epoch=400,
        eval_classes=[1, 7, 282, 604, 724, 207, 250, 751, 404, 850], # goldfish, cock, tiger cat, hourglass, ship, golden retriever, husky, race car, airliner, teddy bear
        blr=1e-4,
        cosine_lr=False,
        lr_min=0,
        warmup_epochs=100,
        warmup_steps=None,
        warmup_lr_init=0,
        decay_steps=None,
        batch_size=32,
        cache_bs=8,
        test_bs=100,
        num_workers=8,
        pin_memory=False,
        max_grad_norm=None,
        grad_accum_steps=1,
        precision='bf16',
        save_every=10000,
        sample_every=1000,
        fid_every=50000,
        result_folder=None,
        log_dir="./log",
        ae_cfg=1.0,
        cfg=6.0,
        cfg_schedule="linear",
        temperature=1.0,
        train_num_slots=None,
        test_num_slots=None,
        eval_fid=False,
        fid_stats=None,
        enable_ema=False,
        compile=False,
        enable_cache_latents=True,
        cache_dir='/dev/shm/slot_cache'
    ):
        """Create the accelerator, the frozen autoencoder, the GPT model and the training state.

        Args:
            ae_model: Config node for the autoencoder. It is passed to
                ``instantiate_from_config``; ``params.ckpt_path``,
                ``params.num_slots`` and ``params.slot_dim`` are read from it.
            gpt_model: Config node for the GPT. ``target`` selects an entry of
                ``GPT_models`` and ``params`` are its keyword arguments
                (``params.num_classes`` and ``params.ckpt_path`` are read here).
            dataset: Config node of the training dataset; only instantiated
                when ``test_only`` is false.
            test_only: Skip dataset/optimizer/scheduler creation and only
                prepare the model for evaluation.
            num_test_images: Image budget for FID evaluation.
            num_epoch: Number of training epochs.
            eval_classes: Class ids sampled for the preview grid. The sequence
                is copied, so the shared default list is never aliased.
            blr: Base learning rate; the actual rate is
                ``blr * effective_batch_size / 256``.
            cosine_lr, lr_min, warmup_steps, warmup_lr_init, decay_steps:
                Forwarded to ``create_scheduler``.
            warmup_epochs: When not ``None`` it overrides ``warmup_steps``
                with ``warmup_epochs * len(train_dl)``.
            batch_size: Per-process training batch size.
            cache_bs: Batch size of the raw-image loader when latents are
                cached (the loader is then only used to build the cache).
            test_bs: Per-process batch size for FID generation.
            num_workers, pin_memory: ``DataLoader`` options.
            max_grad_norm: Gradient clipping threshold, ``None`` disables it.
            grad_accum_steps: Gradient accumulation steps for the accelerator.
            precision: Mixed-precision mode handed to ``Accelerator``.
            save_every, sample_every, fid_every: Step intervals for
                checkpointing, preview sampling and FID evaluation.
            result_folder: Root folder handed to ``setup_result_folders``.
            log_dir: ``project_dir`` of the accelerator (tensorboard logs).
            ae_cfg, cfg, cfg_schedule, temperature: Sampling settings stored
                for evaluation.
            train_num_slots: Number of slots to train on, capped at the
                autoencoder's ``num_slots``; ``None`` means all of them.
            test_num_slots: Number of slots to generate, capped at
                ``train_num_slots``; ``None`` means ``train_num_slots``.
            eval_fid: Enable FID evaluation (requires ``fid_stats``).
            fid_stats: Reference statistics passed to ``get_fid_stats``.
            enable_ema: Maintain an EMA copy of the GPT weights.
            compile: Wrap models with ``torch.compile``.
            enable_cache_latents: Train from cached slots instead of raw images.
            cache_dir: Directory for the per-rank cache files.
        """
        super().__init__()
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            kwargs_handlers=[ddp_kwargs],
            mixed_precision=precision,
            gradient_accumulation_steps=grad_accum_steps,
            log_with="tensorboard",
            project_dir=log_dir,
        )

        # ---- frozen autoencoder -------------------------------------------
        self.ae_model = instantiate_from_config(ae_model)
        ae_model_path = ae_model.params.ckpt_path
        assert ae_model_path.endswith((".safetensors", ".pt", ".pth", ".pkl")), (
            f"AE model checkpoint {ae_model_path} has an unsupported extension"
        )
        assert osp.exists(ae_model_path), f"AE model checkpoint {ae_model_path} does not exist"
        self._load_checkpoint(ae_model_path, self.ae_model)

        self.ae_model.to(self.device)
        for param in self.ae_model.parameters():
            param.requires_grad = False
        self.ae_model.eval()

        # ---- GPT model ----------------------------------------------------
        self.model_name = gpt_model.target
        if 'GPT' in gpt_model.target:
            self.gpt_model = GPT_models[gpt_model.target](**gpt_model.params)
        else:
            raise ValueError(f"Unknown model type: {gpt_model.target}")
        self.num_slots = ae_model.params.num_slots
        self.slot_dim = ae_model.params.slot_dim

        self.test_only = test_only
        self.test_bs = test_bs
        self.num_test_images = num_test_images
        self.num_classes = gpt_model.params.num_classes
        self.batch_size = batch_size
        if not test_only:
            self.train_ds = instantiate_from_config(dataset)
            train_size = len(self.train_ds)
            if self.accelerator.is_main_process:
                print(f"train dataset size: {train_size}")

            sampler = DistributedSampler(
                self.train_ds,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.process_index,
                shuffle=True,
            )
            # With latent caching the raw-image loader only feeds the cache
            # builder, so it runs with the (smaller) cache batch size.
            self.train_dl = DataLoader(
                self.train_ds,
                batch_size=batch_size if not enable_cache_latents else cache_bs,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )

            effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
            lr = blr * effective_bs / 256
            if self.accelerator.is_main_process:
                print(f"Effective batch size is {effective_bs}")

            self.g_optim = create_optimizer(self.gpt_model, weight_decay=0.05, learning_rate=lr)

            # NOTE(review): len(self.train_dl) is measured with cache_bs when
            # latents are cached, while training then iterates the cache loader
            # with batch_size -- confirm the scheduler's steps-per-epoch is the
            # intended one in that mode.
            if warmup_epochs is not None:
                warmup_steps = warmup_epochs * len(self.train_dl)

            self.g_sched = create_scheduler(
                self.g_optim,
                num_epoch,
                len(self.train_dl),
                lr_min,
                warmup_steps,
                warmup_lr_init,
                decay_steps,
                cosine_lr
            )
            self.accelerator.register_for_checkpointing(self.g_sched)
            self.gpt_model, self.g_optim, self.g_sched = self.accelerator.prepare(self.gpt_model, self.g_optim, self.g_sched)
        else:
            self.gpt_model = self.accelerator.prepare(self.gpt_model)

        self.steps = 0
        self.loaded_steps = -1

        if compile:
            self.ae_model = torch.compile(self.ae_model, mode="reduce-overhead")
            _model = self.accelerator.unwrap_model(self.gpt_model)
            # NOTE(review): the wrapper returned by torch.compile is not stored
            # anywhere, so self.gpt_model appears to keep running uncompiled --
            # confirm whether that is intended.
            _model = torch.compile(_model, mode="reduce-overhead")

        self.enable_ema = enable_ema
        if self.enable_ema and not self.test_only: # when testing, we directly load the ema dict and skip here
            self.ema_model = EMAModel(self.accelerator.unwrap_model(self.gpt_model), self.device)
            self.accelerator.register_for_checkpointing(self.ema_model)

        self._load_checkpoint(gpt_model.params.ckpt_path)
        if self.test_only:
            self.steps = self.loaded_steps

        # ---- bookkeeping / sampling settings ------------------------------
        self.num_epoch = num_epoch
        self.save_every = save_every
        self.sample_every = sample_every
        self.fid_every = fid_every
        self.max_grad_norm = max_grad_norm
        # Copy so that neither the shared default list nor a caller-owned
        # sequence is aliased by the trainer.
        self.eval_classes = list(eval_classes) if eval_classes is not None else None
        self.cfg = cfg
        self.ae_cfg = ae_cfg
        self.cfg_schedule = cfg_schedule
        self.temperature = temperature
        self.train_num_slots = train_num_slots
        self.test_num_slots = test_num_slots
        if self.train_num_slots is not None:
            self.train_num_slots = min(self.train_num_slots, self.num_slots)
        else:
            self.train_num_slots = self.num_slots
        if self.test_num_slots is not None:
            self.num_slots_to_gen = min(self.test_num_slots, self.train_num_slots)
        else:
            self.num_slots_to_gen = self.train_num_slots
        self.eval_fid = eval_fid
        if eval_fid:
            assert fid_stats is not None, "fid_stats is required when eval_fid is enabled"
        self.fid_stats = fid_stats

        # Setup result folders
        self.result_folder = result_folder
        self.model_saved_dir, self.image_saved_dir = setup_result_folders(result_folder)

        # Setup cache
        self.cache_dir = Path(cache_dir)
        self.enable_cache_latents = enable_cache_latents
        self.cache_loader = None

    @property
    def device(self):
        return self.accelerator.device

    def _load_checkpoint(self, ckpt_path=None, model=None):
        if ckpt_path is None or not osp.exists(ckpt_path):
            return

        if model is None:
            model = self.accelerator.unwrap_model(self.gpt_model)

        if osp.isdir(ckpt_path):
            self.loaded_steps = int(
                ckpt_path.split("step")[-1].split("/")[0]
            )
            if not self.test_only:
                self.accelerator.load_state(ckpt_path)
            else:
                if self.enable_ema:
                    model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
                    if osp.exists(model_path):
                        state_dict = torch.load(model_path, map_location="cpu")
                        load_state_dict(state_dict, model)
                        if self.accelerator.is_main_process:
                            print(f"Loaded ema model from {model_path}")
                else:
                    model_path = osp.join(ckpt_path, "model.safetensors")
                    if osp.exists(model_path):
                        load_safetensors(model_path, model)
        else:
            if ckpt_path.endswith(".safetensors"):
                load_safetensors(ckpt_path, model)
            else:
                state_dict = torch.load(ckpt_path, map_location="cpu")
                load_state_dict(state_dict, model)
        if self.accelerator.is_main_process:
            print(f"Loaded checkpoint from {ckpt_path}")

    def _build_cache(self):
        """Encode this rank's training shard once and keep it in memory-mapped files.

        Every batch of ``self.train_dl`` is pushed through
        ``self.ae_model.encode_slots``; the first ``self.train_num_slots`` slots
        and the targets are written to ``slots_rank{r}_of_{w}.mmap`` and
        ``targets_rank{r}_of_{w}.mmap`` inside ``self.cache_dir``. Afterwards
        the files are reopened read-only as ``self.cached_latents`` and
        ``self.cached_targets`` and ``self.num_augs`` is recorded.

        Batches may be 4-D ``[B, C, H, W]`` or 5-D ``[B, num_augs, C, H, W]``;
        in the 5-D case every augmentation is stored as its own row.
        """
        rank = self.accelerator.process_index
        world_size = self.accelerator.num_processes

        slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
        targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"

        # Clean up any existing cache files first
        for stale_file in (slots_file, targets_file):
            if stale_file.exists():
                os.remove(stale_file)

        # Detect number of augmentations from first batch
        with torch.no_grad():
            img, _ = next(iter(self.train_dl))
            num_augs = img.shape[1] if len(img.shape) == 5 else 1

        # Size the cache by what the loader will really deliver. Using
        # dataset_size // world_size is wrong in both directions: drop_last
        # leaves trailing rows that stay zero (and would be trained on as
        # "class 0"), and the sampler's padding can yield more rows than
        # allocated.
        loader_bs = self.train_dl.batch_size
        if self.train_dl.drop_last:
            images_per_rank = len(self.train_dl) * loader_bs
        else:
            images_per_rank = len(self.train_dl.sampler)
        num_rows = images_per_rank * num_augs

        print(f"Rank {rank}: Creating new cache with {num_augs} augmentations per image...")
        os.makedirs(self.cache_dir, exist_ok=True)

        slots_shape = (num_rows, self.train_num_slots, self.slot_dim)
        targets_shape = (num_rows,)

        # Create memory-mapped files
        slots_mmap = np.memmap(slots_file, dtype='float32', mode='w+', shape=slots_shape)
        targets_mmap = np.memmap(targets_file, dtype='int64', mode='w+', shape=targets_shape)

        # Cache data
        with torch.no_grad():
            for i, batch in enumerate(tqdm(
                self.train_dl,
                desc=f"Rank {rank}: Caching data",
                disable=not self.accelerator.is_local_main_process
            )):
                imgs, targets = batch
                if len(imgs.shape) == 5:  # [B, num_augs, C, H, W]
                    B, A, C, H, W = imgs.shape
                    imgs = imgs.view(-1, C, H, W)  # [B*num_augs, C, H, W]
                    targets = targets.unsqueeze(1).expand(-1, A).reshape(-1)  # [B*num_augs]

                # Encode in num_augs chunks so that one forward pass never
                # sees more than a loader batch worth of images.
                split_size = imgs.shape[0] // num_augs
                imgs_splits = torch.split(imgs, split_size)
                targets_splits = torch.split(targets, split_size)

                start_idx = i * loader_bs * num_augs

                for split_idx, (img_split, targets_split) in enumerate(zip(imgs_splits, targets_splits)):
                    img_split = img_split.to(self.device, non_blocking=True)
                    slots_split = self.ae_model.encode_slots(img_split)[:, :self.train_num_slots, :]

                    split_start = start_idx + (split_idx * split_size)
                    split_end = split_start + img_split.shape[0]

                    # Write directly to mmap files; .float() keeps this working
                    # if the encoder returns a reduced-precision tensor, which
                    # numpy cannot represent.
                    slots_mmap[split_start:split_end] = slots_split.float().cpu().numpy()
                    targets_mmap[split_start:split_end] = targets_split.numpy()

        # Flush and close the writable maps before reopening read-only
        slots_mmap.flush()
        targets_mmap.flush()
        del slots_mmap
        del targets_mmap

        # Reopen in read mode
        self.cached_latents = np.memmap(slots_file, dtype='float32', mode='r', shape=slots_shape)
        self.cached_targets = np.memmap(targets_file, dtype='int64', mode='r', shape=targets_shape)

        # Store the number of augmentations for the cache loader
        self.num_augs = num_augs

    def _setup_cache(self):
        """Setup cache if enabled."""
        self._build_cache()
        self.accelerator.wait_for_everyone()

        # Initialize cache loader if cache exists
        if self.cached_latents is not None:
            self.cache_loader = CacheDataLoader(
                slots=self.cached_latents,
                targets=self.cached_targets,
                batch_size=self.batch_size,
                num_augs=self.num_augs,
                seed=42 + self.accelerator.process_index
            )

    def __del__(self):
        """Cleanup cache files."""
        if self.enable_cache_latents:
            rank = self.accelerator.process_index
            world_size = self.accelerator.num_processes
            
            # Clean up slots cache
            slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
            if slots_file.exists():
                os.remove(slots_file)
            
            # Clean up targets cache
            targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
            if targets_file.exists():
                os.remove(targets_file)

    def _train_step(self, slots, targets=None):
        """Execute single training step."""
        
        with self.accelerator.accumulate(self.gpt_model):
            with self.accelerator.autocast():
                loss = self.gpt_model(slots, targets)
            
            self.accelerator.backward(loss)
            if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                self.accelerator.clip_grad_norm_(self.gpt_model.parameters(), self.max_grad_norm)
            self.g_optim.step()
            if self.g_sched is not None:
                self.g_sched.step_update(self.steps)
            self.g_optim.zero_grad()

        # Update EMA model if enabled
        if self.enable_ema:
            self.ema_model.update(self.accelerator.unwrap_model(self.gpt_model))
        
        return loss

    def _train_epoch_cached(self, epoch, logger):
        """Train one epoch using cached data."""
        self.cache_loader.set_epoch(epoch)
        header = f'Epoch: [{epoch}/{self.num_epoch}]'
        
        for batch in logger.log_every(self.cache_loader, 20, header):
            slots, targets = (b.to(self.device, non_blocking=True) for b in batch)
            
            self.steps += 1

            if self.steps == 1:
                print(f"Training batch size: {len(slots)}")
                print(f"Hello from index {self.accelerator.local_process_index}")
            
            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _train_epoch_uncached(self, epoch, logger):
        """Train one epoch using raw data."""
        header = f'Epoch: [{epoch}/{self.num_epoch}]'
        
        for batch in logger.log_every(self.train_dl, 20, header):
            img, targets = (b.to(self.device, non_blocking=True) for b in batch)
            
            self.steps += 1
            
            if self.steps == 1:
                print(f"Training batch size: {img.size(0)}")
                print(f"Hello from index {self.accelerator.local_process_index}")

            slots = self.ae_model.encode_slots(img)[:, :self.train_num_slots, :]
            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _handle_periodic_ops(self, loss, logger):
        """Handle periodic operations and logging."""
        logger.update(loss=loss.item())
        logger.update(lr=self.g_optim.param_groups[0]["lr"])
        
        if self.steps % self.save_every == 0:
            self.save()
        
        if (self.steps % self.sample_every == 0) or (self.eval_fid and self.steps % self.fid_every == 0):
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()

    def _save_config(self, config):
        """Save configuration file."""
        if config is not None and self.accelerator.is_main_process:
            import shutil
            from omegaconf import OmegaConf

            if isinstance(config, str) and osp.exists(config):
                shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
            else:
                config_save_path = osp.join(self.result_folder, "config.yaml")
                OmegaConf.save(config, config_save_path)

    def _should_skip_epoch(self, epoch):
        """Check if epoch should be skipped due to loaded checkpoint."""
        loader = self.train_dl if not self.enable_cache_latents else self.cache_loader
        if ((epoch + 1) * len(loader)) <= self.loaded_steps:
            if self.accelerator.is_main_process:
                print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
            self.steps += len(loader)
            return True
        
        if self.steps < self.loaded_steps:
            for _ in loader:
                self.steps += 1
                if self.steps >= self.loaded_steps:
                    break
        return False

    def train(self, config=None):
        """Main training loop."""
        # Initial setup
        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
        if self.accelerator.is_main_process:
            print(f"number of learnable parameters: {n_parameters//1e6}M")
        
        self._save_config(config)
        self.accelerator.init_trackers("gpt")

        # Handle test-only mode
        if self.test_only:
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()
            return

        # Setup cache if enabled
        if self.enable_cache_latents:
            self._setup_cache()

        # Training loop
        for epoch in range(self.num_epoch):
            if self._should_skip_epoch(epoch):
                continue

            self.gpt_model.train()
            logger = MetricLogger(delimiter="  ")
            logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

            # Choose training path based on cache availability
            if self.enable_cache_latents:
                self._train_epoch_cached(epoch, logger)
            else:
                self._train_epoch_uncached(epoch, logger)

            # Synchronize and log epoch stats
            logger.synchronize_between_processes()
            if self.accelerator.is_main_process:
                print("Averaged stats:", logger)

        # Finish training
        self.accelerator.end_training()
        self.save()
        if self.accelerator.is_main_process:
            print("Train finished!")

    def save(self):
        self.accelerator.wait_for_everyone()
        self.accelerator.save_state(
            os.path.join(self.model_saved_dir, f"step{self.steps}")
        )

    @torch.no_grad()
    def evaluate(self, use_ema=True):
        """Sample a preview grid and, when due, generate images and compute FID / IS.

        The preview grid is produced whenever the trainer is not in test-only
        mode. FID evaluation runs when ``eval_fid`` is set and either the
        trainer is in test-only mode or ``self.steps`` is a multiple of
        ``fid_every``.

        Args:
            use_ema: Evaluate with EMA weights. Only honoured when EMA and FID
                evaluation are enabled and the trainer is not in test-only
                mode (in that mode the EMA weights were loaded directly).
        """
        self.gpt_model.eval()
        unwrapped_gpt_model = self.accelerator.unwrap_model(self.gpt_model)
        # switch to ema params, only when eval_fid is True
        # if test_only, we directly load the ema dict and skip here
        use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
        model_state_dict = None
        if use_ema:
            if hasattr(self, "ema_model"):
                model_state_dict = copy.deepcopy(unwrapped_gpt_model.state_dict())
                ema_state_dict = copy.deepcopy(unwrapped_gpt_model.state_dict())
                # Fetch the EMA weights once instead of once per parameter.
                ema_weights = self.ema_model.state_dict()
                for name, _value in unwrapped_gpt_model.named_parameters():
                    if "nested_sampler" in name:
                        continue
                    ema_state_dict[name] = ema_weights[name]
                if self.accelerator.is_main_process:
                    print("Switch to ema")
                unwrapped_gpt_model.load_state_dict(ema_state_dict)
            else:
                print("EMA model not found, using original model")
                use_ema = False

        try:
            if not self.test_only:
                self._save_preview_grid(unwrapped_gpt_model)
            if self.eval_fid and (self.test_only or (self.steps % self.fid_every == 0)):
                self._run_fid_eval(unwrapped_gpt_model, use_ema)
        finally:
            # Always put the training weights and train mode back, even if
            # generation or metric computation failed.
            if use_ema:
                if self.accelerator.is_main_process:
                    print("Switch back from ema")
                unwrapped_gpt_model.load_state_dict(model_state_dict)

            self.gpt_model.train()

    @torch.no_grad()
    def _sample_images(self, gpt_model, labels):
        """Generate slots for ``labels`` with the GPT and decode them with the autoencoder."""
        with self.accelerator.autocast():
            slots = generate(
                gpt_model, labels, self.num_slots_to_gen,
                cfg_scale=self.cfg,
                cfg_schedule=self.cfg_schedule,
                temperature=self.temperature,
            )
            if self.num_slots_to_gen < self.num_slots:
                # Fill the slots that were not generated with the decoder's
                # null condition.
                null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
                null_slots = null_slots[:, self.num_slots_to_gen:, :]
                slots = torch.cat([slots, null_slots], dim=1)
            return self.ae_model.sample(slots, targets=labels, cfg=self.ae_cfg)  # targets are not used for now

    @torch.no_grad()
    def _save_preview_grid(self, gpt_model):
        """Sample one image per entry of ``eval_classes`` and save them as a grid."""
        classes = torch.tensor(self.eval_classes, device=self.device)
        imgs = self._sample_images(gpt_model, classes)

        imgs = concat_all_gather(imgs)
        if self.accelerator.num_processes > 16:
            imgs = imgs[:16*len(self.eval_classes)]
        imgs = imgs.detach().cpu()
        grid = make_grid(
            imgs, nrow=len(self.eval_classes), normalize=True, value_range=(0, 1)
        )
        if self.accelerator.is_main_process:
            save_image(
                grid,
                os.path.join(
                    self.image_saved_dir, f"step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}.jpg"
                ),
            )

    @torch.no_grad()
    def _run_fid_eval(self, gpt_model, use_ema):
        """Generate a class-balanced image set across all processes and log FID / IS."""
        # Create output directory (only on main process)
        save_folder = os.path.join(self.image_saved_dir, f"gen_step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}")
        if self.accelerator.is_main_process:
            os.makedirs(save_folder, exist_ok=True)

        # Setup for distributed generation
        world_size = self.accelerator.num_processes
        rank = self.accelerator.process_index
        batch_size = self.test_bs
        global_batch = world_size * batch_size

        # Create balanced class distribution
        num_classes = self.num_classes
        images_per_class = self.num_test_images // num_classes
        class_labels = np.repeat(np.arange(num_classes), images_per_class)

        # Shuffle with a fixed, process-independent seed: every rank must see
        # the same label order, otherwise the union of the per-rank slices is
        # no longer class-balanced.
        np.random.RandomState(0).shuffle(class_labels)

        total_images = len(class_labels)

        # Pad up to a multiple of the global batch; no padding when the count
        # already is a multiple.
        padding_size = (-total_images) % global_batch
        class_labels = np.pad(class_labels, (0, padding_size), 'constant')
        num_steps = len(class_labels) // global_batch

        if self.accelerator.is_main_process:
            print(f"Generating {total_images} images ({images_per_class} per class) across {world_size} GPUs")

        used_time = 0
        gen_img_cnt = 0

        for i in range(num_steps):
            if self.accelerator.is_main_process and i % 10 == 0:
                print(f"Generation step {i}/{num_steps}")

            # Interleave ranks inside every global batch so that the position
            # of an image in the gathered batch equals its position in
            # class_labels. Truncating the last gathered batch then removes
            # exactly the padding entries.
            # (assumes concat_all_gather concatenates in rank order -- TODO confirm)
            batch_start = i * global_batch + rank * batch_size
            labels = torch.tensor(
                class_labels[batch_start:batch_start + batch_size], device=self.device
            )

            # Generate images
            self.accelerator.wait_for_everyone()
            start_time = time.time()
            imgs = self._sample_images(gpt_model, labels)

            samples_in_batch = min(global_batch, total_images - gen_img_cnt)

            # Update timing stats
            used_time += time.time() - start_time
            gen_img_cnt += samples_in_batch
            if self.accelerator.is_main_process and i % 10 == 0:
                print(f"Avg generation time: {used_time/gen_img_cnt:.5f} sec/image")

            gathered_imgs = concat_all_gather(imgs)
            gathered_imgs = gathered_imgs[:samples_in_batch]

            # Save images (only on main process)
            if self.accelerator.is_main_process:
                real_imgs = gathered_imgs.detach().cpu()

                save_paths = [
                    os.path.join(save_folder, f"{str(idx).zfill(5)}.png")
                    for idx in range(gen_img_cnt - samples_in_batch, gen_img_cnt)
                ]
                save_img_batch(real_imgs, save_paths)

        # Calculate metrics (only on main process)
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            generated_files = len(os.listdir(save_folder))
            print(f"Generated {generated_files} images out of {total_images} expected")

            metrics_dict = get_fid_stats(save_folder, None, self.fid_stats)
            fid = metrics_dict["frechet_inception_distance"]
            inception_score = metrics_dict["inception_score_mean"]

            metric_prefix = "fid_ema" if use_ema else "fid"
            isc_prefix = "isc_ema" if use_ema else "isc"

            self.accelerator.log({
                metric_prefix: fid,
                isc_prefix: inception_score,
                "gpt_cfg": self.cfg,
                "ae_cfg": self.ae_cfg,
                "cfg_schedule": self.cfg_schedule,
                "temperature": self.temperature,
                "num_slots": self.test_num_slots if self.test_num_slots is not None else self.train_num_slots
            }, step=self.steps)

            # Print comprehensive CFG information
            cfg_info = (
                f"{'EMA ' if use_ema else ''}CFG params: "
                f"gpt_cfg={self.cfg}, ae_cfg={self.ae_cfg}, "
                f"cfg_schedule={self.cfg_schedule}, "
                f"num_slots={self.test_num_slots if self.test_num_slots is not None else self.train_num_slots}, "
                f"temperature={self.temperature}"
            )
            print(cfg_info)
            print(f"FID: {fid:.2f}, ISC: {inception_score:.2f}")

            # Cleanup
            shutil.rmtree(save_folder)