# File size: 2,198 Bytes
# 8e12b4e
import os
import random
import re
import torch
from torch.utils.data import Dataset
class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""

    def __init__(self, opt):
        """Build the cleaned prompt list and per-sample bookkeeping.

        Args:
            opt (dict): options with keys:
                - 'prompts': either a list of prompt strings, or a path to a
                  text file containing one prompt per line.
                - 'num_samples_per_prompt' (int): number of latents drawn per prompt.
                - 'latent_size' (tuple): shape of the latent noise, e.g. (4, 64, 64).
                - 'replace_mapping' (dict, optional): substring -> replacement pairs
                  applied to every prompt.
                - 'share_latent_across_prompt' (bool, optional, default True): if True,
                  the i-th sample of every prompt is seeded with i, so the same
                  latent is shared across prompts.

        Raises:
            ValueError: if 'prompts' is neither a list nor an existing file path.
        """
        self.opt = opt
        prompts = opt['prompts']
        if isinstance(prompts, list):
            self.prompts = prompts
        elif os.path.exists(prompts):
            # 'prompts' is a file path: one prompt per line.
            with open(prompts, 'r', encoding='utf-8') as fr:
                self.prompts = [item.strip() for item in fr.readlines()]
        else:
            raise ValueError(
                'prompts should be a prompt file path or prompt list, please check!'
            )
        self.prompts = self.replace_placeholder(self.prompts)
        self.num_samples_per_prompt = opt['num_samples_per_prompt']
        # Enumerate (prompt, sample_index) pairs; sample indices are 1-based so
        # they double as deterministic seeds below.
        self.prompts_to_generate = [
            (p, i) for i in range(1, self.num_samples_per_prompt + 1)
            for p in self.prompts
        ]
        self.latent_size = opt['latent_size']  # e.g. (4, 64, 64)
        self.share_latent_across_prompt = opt.get('share_latent_across_prompt', True)

    def replace_placeholder(self, prompts):
        """Apply the configured placeholder replacements to a list of prompts.

        Blank lines are dropped and runs of spaces are collapsed to one.

        Args:
            prompts (list[str]): raw prompt strings.

        Returns:
            list[str]: cleaned prompts.
        """
        replace_mapping = self.opt.get('replace_mapping', {})
        new_lines = []
        # BUGFIX: iterate the `prompts` argument; the original iterated
        # self.prompts, silently ignoring the parameter.
        for line in prompts:
            if len(line.strip()) == 0:
                continue  # skip empty / whitespace-only prompts
            for k, v in replace_mapping.items():
                line = line.replace(k, v)
            line = line.strip()
            line = re.sub(' +', ' ', line)
            new_lines.append(line)
        return new_lines

    def __len__(self):
        """Return the total sample count (= len(prompts) * num_samples_per_prompt)."""
        return len(self.prompts_to_generate)

    def __getitem__(self, index):
        """Return {'prompts': str, 'indices': int, 'latents': Tensor}.

        When share_latent_across_prompt is True, the seed equals the 1-based
        sample index, so sample i of every prompt shares one latent; otherwise
        a random seed is drawn and each call yields a fresh latent.
        """
        prompt, indice = self.prompts_to_generate[index]
        example = {}
        example['prompts'] = prompt
        example['indices'] = indice
        if self.share_latent_across_prompt:
            seed = indice
        else:
            seed = random.randint(0, 1000)
        example['latents'] = torch.randn(self.latent_size, generator=torch.manual_seed(seed))
        return example