File size: 912 Bytes
c09bcc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
import gc
from pathlib import Path
from tqdm.auto import tqdm
from dataset import PromptDataset
def generate_class_images(pipeline, class_prompt, num_class_images, class_images_dir, sample_batch_size=2):
    """Top up ``class_images_dir`` with images generated from ``class_prompt``.

    Counts the entries already present in ``class_images_dir``. If there are
    fewer than ``num_class_images``, builds a ``PromptDataset`` for the
    shortfall and runs ``pipeline`` on it in batches of ``sample_batch_size``.
    Every returned image is saved as ``<n>.jpg``. Numbering continues from the
    current entry count, and any number whose file already exists is skipped,
    so existing images are never overwritten.

    Args:
        pipeline: Callable that takes a batch of prompts (``example["prompt"]``)
            and returns an object whose ``.images`` is an iterable of objects
            with a ``save(path)`` method.
        class_prompt: Prompt passed to ``PromptDataset``.
        num_class_images: Target number of entries in ``class_images_dir``.
            Nothing is generated if the directory already has at least this
            many entries.
        class_images_dir: Target directory, as ``str`` or ``pathlib.Path``.
            It is created (including parents) if missing.
        sample_batch_size: Number of prompts per ``pipeline`` call.

    Returns:
        None. Images are written to disk as a side effect.
    """
    class_images_dir = Path(class_images_dir)
    class_images_dir.mkdir(parents=True, exist_ok=True)
    try:
        # Every directory entry counts toward the target (files of any kind and
        # sub-directories alike), as in the original implementation.
        cur_class_images = sum(1 for _ in class_images_dir.iterdir())
        num_new_images = num_class_images - cur_class_images
        if num_new_images <= 0:
            return

        sample_dataset = PromptDataset(class_prompt, num_new_images)
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)

        # Running file number. With a directory already holding 0..n-1 this
        # reproduces "dataset index + current count"; otherwise it skips names
        # that are taken instead of overwriting them.
        next_index = cur_class_images
        for example in tqdm(sample_dataloader, desc="Generating class images"):
            images = pipeline(example["prompt"]).images
            for image in images:
                target = class_images_dir / f"{next_index}.jpg"
                while target.exists():
                    next_index += 1
                    target = class_images_dir / f"{next_index}.jpg"
                image.save(target)
                next_index += 1
    finally:
        # Runs even if the pipeline raises. ``del`` only drops this function's
        # reference; the pipeline is freed only if the caller also releases it.
        del pipeline
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|