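"""Collect SDXL-Turbo UNet activations into webdataset tar shards.

Streams captions from LAION-COCO Aesthetic, runs one-step SDXL-Turbo
generation per batch with hooks on selected attention blocks, and writes
one shard per block (outputs and output-minus-input residuals) plus a
shard of the generated images.
"""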
import datetime
import os
import sys

import fire
import numpy as np
import torch
import webdataset as wds
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make the repository root importable so that SDLens resolves.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline

def main(save_path, start_at=0, finish_at=30000, dataset_batch_size=50):
    # UNet attention (transformer) blocks whose activations are cached.
    blocks_to_save = [
        'unet.down_blocks.2.attentions.1',
        'unet.mid_block.attentions.0',
        'unet.up_blocks.0.attentions.0',
        'unet.up_blocks.0.attentions.1',
    ]

    # Stream shuffled captions from LAION-COCO Aesthetic; only the text column is loaded.
    dataset = load_dataset("guangyil/laion-coco-aesthetic", split="train", columns=["caption"], streaming=True).shuffle(seed=42)
    pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
    pipe.to('cuda')
    pipe.set_progress_bar_config(disable=True)
    dataloader = DataLoader(dataset, batch_size=dataset_batch_size)

    # Collect activations into a fresh timestamped subdirectory per run.
    ct = datetime.datetime.now()
    save_path = os.path.join(save_path, str(ct))
    os.makedirs(save_path, exist_ok=True)

    # One tar shard per hooked block, plus one shard for the generated images.
    writers = {
        block: wds.TarWriter(f'{save_path}/{block}.tar') for block in blocks_to_save
    }
    writers['images'] = wds.TarWriter(f'{save_path}/images.tar')

    def to_kwargs(kwargs_to_save):
        # Convert the serializable argument dict into real pipeline kwargs:
        # the integer 'seed' becomes a torch.Generator.
        kwargs = kwargs_to_save.copy()
        seed = kwargs.pop('seed')
        kwargs['generator'] = torch.Generator(device="cpu").manual_seed(seed)
        return kwargs

    for num_document, batch in tqdm(enumerate(dataloader)):
        # start_at / finish_at are batch indices over the streaming dataloader.
        if num_document < start_at:
            continue

        if num_document >= finish_at:
            break

        # Single-step, guidance-free generation (SDXL-Turbo's intended setting);
        # the seed is kept serializable here and converted by to_kwargs.
        kwargs_to_save = {
            'prompt': batch['caption'],
            'positions_to_cache': blocks_to_save,
            'save_input': True,
            'save_output': True,
            'num_inference_steps': 1,
            'guidance_scale': 0.0,
            'seed': num_document,
            'output_type': 'pil'
        }

        kwargs = to_kwargs(kwargs_to_save)
        output, cache = pipe.run_with_cache(**kwargs)
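        # cache maps 'input' / 'output' to {block_name: activation tensor},
        # one entry per block listed in positions_to_cache.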

        for block in cache['input'].keys():
            # Store the block output and its additive contribution
            # (output - input), plus the arguments needed to reproduce it.
            sample = {
                "__key__": f"sample_{num_document}",
                "output.pth": cache['output'][block],
                "diff.pth": cache['output'][block] - cache['input'][block],
                "gen_args.json": kwargs_to_save
            }

            writers[block].write(sample)
        writers['images'].write({
            "__key__": f"sample_{num_document}",
            "images.npy": np.stack(output.images)
        })

    # Flush and close all shards.
    for writer in writers.values():
        writer.close()

if __name__ == '__main__':
    fire.Fire(main)
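
# Example (sketch): reading one of the per-block shards back with webdataset.
# This assumes webdataset's default decoders, which deserialize '.pth' entries
# with torch and '.json' entries with json; the shard path is illustrative.
#
#   import webdataset as wds
#
#   ds = wds.WebDataset('<save_path>/unet.mid_block.attentions.0.tar').decode()
#   for sample in ds:
#       out = sample['output.pth']    # block output activations
#       diff = sample['diff.pth']     # the block's additive contribution
#       args = sample['gen_args.json']
#       break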