File size: 912 Bytes
c09bcc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import gc
from pathlib import Path
from tqdm.auto import tqdm
from dataset import PromptDataset

def generate_class_images(pipeline, class_prompt, num_class_images, class_images_dir, sample_batch_size=2):
    """Top up ``class_images_dir`` with generated images until it holds ``num_class_images`` entries.

    Every entry already present in the directory counts towards the target, so
    only the shortfall is generated. When the directory already holds enough
    entries the function returns without calling the pipeline.

    Args:
        pipeline: Callable invoked with a batch of prompts; the returned object
            must expose an ``images`` sequence whose items have a
            ``save(path)`` method.
        class_prompt: Prompt forwarded to ``PromptDataset`` for every sample.
        num_class_images: Desired total number of entries in the directory.
        class_images_dir: Output directory (``str`` or ``Path``). It is created,
            including parents, if it does not exist yet.
        sample_batch_size: Number of prompts sent to the pipeline per call.

    Returns:
        None. Images are written to ``class_images_dir`` as ``<n>.jpg``.
    """
    # Accept plain strings as well as Path objects, and make sure the target
    # exists so that counting and saving cannot fail on a fresh directory.
    class_images_dir = Path(class_images_dir)
    class_images_dir.mkdir(parents=True, exist_ok=True)

    cur_class_images = sum(1 for _ in class_images_dir.iterdir())
    num_new_images = num_class_images - cur_class_images
    if num_new_images <= 0:
        return

    sample_dataset = PromptDataset(class_prompt, num_new_images)
    sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)

    for example in tqdm(sample_dataloader, desc="Generating class images"):
        images = pipeline(example["prompt"]).images
        for i, image in enumerate(images):
            # NOTE(review): the batched "index" entries are presumably 0-dim
            # tensors produced by the default collate function; int() keeps the
            # file name a plain integer either way.
            image_number = int(example["index"][i]) + cur_class_images
            image.save(class_images_dir / f"{image_number}.jpg")

    # Drops only this function's reference to the pipeline; the caller must
    # release its own reference for the memory to actually be reclaimed.
    del pipeline
    gc.collect()
    torch.cuda.empty_cache()