Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import sys | |
import io | |
import tarfile | |
import torch | |
import webdataset as wds | |
import numpy as np | |
from tqdm import tqdm | |
sys.path.append(os.path.join(os.path.dirname(__file__), '..')) | |
from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline | |
import datetime | |
from datasets import load_dataset | |
from torch.utils.data import DataLoader | |
import diffusers | |
import fire | |
def main(save_path, start_at=0, finish_at=30000, dataset_batch_size=50): | |
blocks_to_save = [ | |
'unet.down_blocks.2.attentions.1', | |
'unet.mid_block.attentions.0', | |
'unet.up_blocks.0.attentions.0', | |
'unet.up_blocks.0.attentions.1', | |
] | |
# Initialization | |
dataset = load_dataset("guangyil/laion-coco-aesthetic", split="train", columns=["caption"], streaming=True).shuffle(seed=42) | |
pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo') | |
pipe.to('cuda') | |
pipe.set_progress_bar_config(disable=True) | |
dataloader = DataLoader(dataset, batch_size=dataset_batch_size) | |
ct = datetime.datetime.now() | |
save_path = os.path.join(save_path, str(ct)) | |
# Collecting dataset | |
os.makedirs(save_path, exist_ok=True) | |
writers = { | |
block: wds.TarWriter(f'{save_path}/{block}.tar') for block in blocks_to_save | |
} | |
writers.update({'images': wds.TarWriter(f'{save_path}/images.tar')}) | |
def to_kwargs(kwargs_to_save): | |
kwargs = kwargs_to_save.copy() | |
seed = kwargs['seed'] | |
del kwargs['seed'] | |
kwargs['generator'] = torch.Generator(device="cpu").manual_seed(num_document) | |
return kwargs | |
dataloader_iter = iter(dataloader) | |
for num_document, batch in tqdm(enumerate(dataloader)): | |
if num_document < start_at: | |
continue | |
if num_document >= finish_at: | |
break | |
kwargs_to_save = { | |
'prompt': batch['caption'], | |
'positions_to_cache': blocks_to_save, | |
'save_input': True, | |
'save_output': True, | |
'num_inference_steps': 1, | |
'guidance_scale': 0.0, | |
'seed': num_document, | |
'output_type': 'pil' | |
} | |
kwargs = to_kwargs(kwargs_to_save) | |
output, cache = pipe.run_with_cache( | |
**kwargs | |
) | |
blocks = cache['input'].keys() | |
for block in blocks: | |
sample = { | |
"__key__": f"sample_{num_document}", | |
"output.pth": cache['output'][block], | |
"diff.pth": cache['output'][block] - cache['input'][block], | |
"gen_args.json": kwargs_to_save | |
} | |
writers[block].write(sample) | |
writers['images'].write({ | |
"__key__": f"sample_{num_document}", | |
"images.npy": np.stack(output.images) | |
}) | |
for block, writer in writers.items(): | |
writer.close() | |
if __name__ == '__main__': | |
fire.Fire(main) |