# GenAI_Text_to_Image/generate_class_images.py
# Uploaded by kothariyashhh (commit c09bcc2).
import torch
import gc
from pathlib import Path
from tqdm.auto import tqdm
from dataset import PromptDataset
def generate_class_images(pipeline, class_prompt, num_class_images, class_images_dir, sample_batch_size=2):
    """Top up ``class_images_dir`` with generated class images.

    The number of entries already in the directory is compared with
    ``num_class_images``. If there are fewer, ``pipeline`` is run on batches of
    prompts from ``PromptDataset(class_prompt, missing_count)``. Each returned
    image is saved as ``<n>.jpg``, numbered after the existing entries.

    Args:
        pipeline: Callable that takes a batch of prompts and returns an object
            with an ``.images`` sequence of objects exposing ``.save(path)``.
            Presumably a diffusers text-to-image pipeline; confirm at the call
            site.
        class_prompt: Prompt handed to ``PromptDataset`` for the new images.
        num_class_images: Desired total number of entries in the directory.
        class_images_dir: Output directory, as ``str`` or ``pathlib.Path``.
            Created (with parents) if it does not exist.
        sample_batch_size: Number of prompts per ``pipeline`` call.
    """
    class_images_dir = Path(class_images_dir)
    # Without this, iterdir() below raises FileNotFoundError on a fresh run.
    class_images_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): every directory entry is counted, including non-image
    # files. Stray files reduce how many images get generated. If existing
    # names are not 0..N-1, new files may overwrite them.
    cur_class_images = sum(1 for _ in class_images_dir.iterdir())
    num_new_images = num_class_images - cur_class_images
    if num_new_images <= 0:
        return

    sample_dataset = PromptDataset(class_prompt, num_new_images)
    sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)
    for example in tqdm(sample_dataloader, desc="Generating class images"):
        images = pipeline(example["prompt"]).images
        for i, image in enumerate(images):
            # If PromptDataset yields int indices, the default collate batches
            # them into a tensor. int() yields a plain number either way.
            image_index = int(example["index"][i]) + cur_class_images
            image.save(class_images_dir / f"{image_index}.jpg")

    # This only drops the function's own reference. The pipeline's memory is
    # reclaimed only once the caller has released its references too.
    del pipeline
    gc.collect()
    torch.cuda.empty_cache()