|
import torch
|
|
import gc
|
|
from pathlib import Path
|
|
from tqdm.auto import tqdm
|
|
from dataset import PromptDataset
|
|
|
|
def generate_class_images(pipeline, class_prompt, num_class_images, class_images_dir, sample_batch_size=2):
    """Top up ``class_images_dir`` with generated class images.

    Counts the regular files already in ``class_images_dir``. If there are
    fewer than ``num_class_images``, the function asks ``pipeline`` to
    generate ``num_class_images - existing`` more samples of ``class_prompt``,
    in batches of ``sample_batch_size`` prompts. Each returned image is saved
    as ``<n>.jpg``. Numbering starts at the existing file count and skips any
    number whose file already exists, so existing files are never overwritten.

    Args:
        pipeline: Callable taking a batch of prompts and returning an object
            with an ``.images`` sequence of objects that have ``.save(path)``.
            Presumably a diffusers pipeline -- confirm against the caller.
        class_prompt: Prompt used for every generated image.
        num_class_images: Target number of files in ``class_images_dir``.
        class_images_dir: Output directory (``str`` or ``Path``). It is
            created if it does not exist.
        sample_batch_size: Number of prompts passed per pipeline call.

    Note:
        ``del pipeline`` only drops this function's reference. The pipeline's
        (GPU) memory is freed only if the caller drops its references too.
    """
    class_images_dir = Path(class_images_dir)
    class_images_dir.mkdir(parents=True, exist_ok=True)

    # Count only regular files so stray sub-directories are not mistaken for images.
    cur_class_images = sum(1 for entry in class_images_dir.iterdir() if entry.is_file())
    num_new_images = num_class_images - cur_class_images
    if num_new_images <= 0:
        return

    # NOTE(review): assumes PromptDataset yields dicts with a "prompt" key, so the
    # default DataLoader collate turns each batch into a list of prompt strings.
    sample_dataset = PromptDataset(class_prompt, num_new_images)
    sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)

    next_id = cur_class_images
    for example in tqdm(sample_dataloader, desc="Generating class images"):
        images = pipeline(example["prompt"]).images
        for image in images:
            # Existing names need not be 0..cur-1 (e.g. after deletions), so skip
            # any number already on disk instead of overwriting that file.
            while (class_images_dir / f"{next_id}.jpg").exists():
                next_id += 1
            image.save(class_images_dir / f"{next_id}.jpg")
            next_id += 1

    # Drop our reference and release cached CUDA allocator blocks.
    del pipeline
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
|
|