import gc
from pathlib import Path

import torch
from tqdm.auto import tqdm

from dataset import PromptDataset


def generate_class_images(pipeline, class_prompt, num_class_images, class_images_dir, sample_batch_size=2):
    """Top up ``class_images_dir`` with generated images until it holds ``num_class_images`` entries.

    Existing directory entries (of any kind) are counted; only the shortfall is
    generated. New images are saved as ``<N>.jpg``, where ``N`` starts at the
    number of existing entries and skips any name that is already taken, so
    existing files are never overwritten.

    Args:
        pipeline: Callable taking a batch of prompts and returning an object
            with an ``.images`` sequence of savable images (e.g. PIL images) —
            presumably a diffusers pipeline; confirm against callers.
        class_prompt: Prompt passed to ``PromptDataset`` for every new image.
        num_class_images: Target number of entries in ``class_images_dir``.
        class_images_dir: Output directory (``Path`` or ``str``); created if missing.
        sample_batch_size: Number of prompts sent to ``pipeline`` per call.

    Returns:
        The number of images written (0 when the directory is already full).
    """
    class_images_dir = Path(class_images_dir)
    # Accept a not-yet-created directory instead of crashing in iterdir().
    class_images_dir.mkdir(parents=True, exist_ok=True)

    # Counts every entry (files, subdirectories, hidden files), matching the original behaviour.
    cur_class_images = sum(1 for _ in class_images_dir.iterdir())
    num_new_images = num_class_images - cur_class_images
    if num_new_images <= 0:
        return 0

    sample_dataset = PromptDataset(class_prompt, num_new_images)
    sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)

    # Running file counter instead of example["index"][i]: avoids IndexError when the
    # pipeline returns several images per prompt, and lets us skip taken names.
    next_id = cur_class_images
    written = 0
    try:
        for example in tqdm(sample_dataloader, desc="Generating class images"):
            images = pipeline(example["prompt"]).images
            for image in images:
                target = class_images_dir / f"{next_id}.jpg"
                # Existing files need not be numbered 0..cur-1; never clobber one.
                while target.exists():
                    next_id += 1
                    target = class_images_dir / f"{next_id}.jpg"
                image.save(target)
                next_id += 1
                written += 1
    finally:
        # Drops only this function's reference. The pipeline is freed only if the caller
        # holds no other reference (e.g. it was passed as a temporary).
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()
    return written