ortha / mixofshow /data /prompt_dataset.py
ujin-song's picture
upload mixofshow and orthogonal_mats folder
8e12b4e verified
raw
history blame
2.2 kB
import os
import random
import re
import torch
from torch.utils.data import Dataset
class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs.

    Every cleaned prompt is paired with each sample index in
    ``1..num_samples_per_prompt``; one dataset item is one such pair together
    with a freshly drawn latent tensor.

    Keys read from ``opt``:
        prompts: a list/tuple of prompt strings, or a path to a text file
            with one prompt per line.
        num_samples_per_prompt: number of samples generated for each prompt.
        latent_size: size passed to ``torch.randn``, e.g. ``(4, 64, 64)``.
        share_latent_across_prompt (optional, default ``True``): when true,
            the sample index is used as the seed, so every prompt gets the
            same latent for a given sample index; otherwise the seed is drawn
            with ``random.randint(0, 1000)``.
        replace_mapping (optional, default ``{}``): ``{placeholder: text}``
            substitutions applied to every prompt.

    Raises:
        ValueError: if ``opt['prompts']`` is neither a list/tuple nor an
            existing file path.
    """

    def __init__(self, opt):
        self.opt = opt

        prompts = opt['prompts']
        if isinstance(prompts, (list, tuple)):
            prompts = list(prompts)
        elif os.path.exists(prompts):
            # is file: one prompt per line
            with open(prompts, 'r') as fr:
                prompts = [line.strip() for line in fr]
        else:
            raise ValueError(
                'prompts should be a prompt file path or prompt list, please check!'
            )

        # Keep the raw prompts on the instance while cleaning, so overrides of
        # replace_placeholder that still read self.prompts keep working.
        self.prompts = prompts
        self.prompts = self.replace_placeholder(prompts)

        self.num_samples_per_prompt = opt['num_samples_per_prompt']
        # Sample index is the outer loop: all prompts for sample 1 come
        # first, then all prompts for sample 2, and so on.
        self.prompts_to_generate = [
            (prompt, sample_idx)
            for sample_idx in range(1, self.num_samples_per_prompt + 1)
            for prompt in self.prompts
        ]

        self.latent_size = opt['latent_size']  # e.g. (4, 64, 64)
        self.share_latent_across_prompt = opt.get(
            'share_latent_across_prompt', True)  # (true, false)

    def replace_placeholder(self, prompts):
        """Return cleaned copies of ``prompts`` with placeholders substituted.

        Blank (whitespace-only) entries are dropped. For the rest, every key
        of ``opt['replace_mapping']`` is replaced by its value, the result is
        stripped, and runs of spaces are collapsed to a single space.

        Args:
            prompts: iterable of prompt strings.

        Returns:
            A new list of cleaned prompt strings.
        """
        replace_mapping = self.opt.get('replace_mapping', {})
        new_lines = []
        for line in prompts:
            if not line.strip():
                continue
            for placeholder, replacement in replace_mapping.items():
                line = line.replace(placeholder, replacement)
            new_lines.append(re.sub(' +', ' ', line.strip()))
        return new_lines

    def __len__(self):
        """Return the number of (prompt, sample index) pairs."""
        return len(self.prompts_to_generate)

    def __getitem__(self, index):
        """Return a dict with keys ``'prompts'``, ``'indices'`` and ``'latents'``."""
        prompt, indice = self.prompts_to_generate[index]

        if self.share_latent_across_prompt:
            # Seeding with the sample index gives every prompt the same
            # latent for a given sample index.
            seed = indice
        else:
            seed = random.randint(0, 1000)

        # Use a private generator rather than torch.manual_seed, so fetching
        # an item does not reseed the process-wide default RNG.
        generator = torch.Generator().manual_seed(seed)

        return {
            'prompts': prompt,
            'indices': indice,
            'latents': torch.randn(self.latent_size, generator=generator),
        }