import json
import os
import random
import re
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

from mixofshow.data.pil_transform import PairCompose, build_transform


class LoraDataset(Dataset):
    """Dataset of (image, prompt, optional mask) samples for fine-tuning.

    The sample list is built once in ``__init__`` from the JSON concept list
    referenced by ``opt['concept_list']``. Each concept entry provides
    ``instance_prompt`` and ``instance_data_dir`` and may provide
    ``caption_dir`` and ``mask_dir``.

    Options read from ``opt``:
        concept_list: path of the JSON file holding the list of concepts.
        replace_mapping: optional ``{old: new}`` substring replacements
            applied to every prompt (default ``{}``).
        use_caption: when true, a per-image ``<basename>.txt`` file in the
            concept's ``caption_dir`` overrides the concept prompt.
        use_mask: when true, ``<basename>.png`` in the concept's ``mask_dir``
            is recorded as the mask of the image.
        instance_transform: list of transform options, each handed to
            ``build_transform`` and chained with ``PairCompose``.
        dataset_enlarge_ratio: multiplier applied to the reported length.

    Prompts are returned as plain strings; no tokenization happens here.
    """

    def __init__(self, opt):
        self.opt = opt
        self.instance_images_path = []

        with open(opt['concept_list'], 'r') as f:
            concept_list = json.load(f)

        replace_mapping = opt.get('replace_mapping', {})
        use_caption = opt.get('use_caption', False)
        use_mask = opt.get('use_mask', False)

        for concept in concept_list:
            caption_dir = concept.get('caption_dir')
            mask_dir = concept.get('mask_dir')
            instance_prompt = self.process_text(concept['instance_prompt'], replace_mapping)

            for x in Path(concept['instance_data_dir']).iterdir():
                # Only regular files count as images; skip macOS folder metadata.
                if not x.is_file() or x.name == '.DS_Store':
                    continue

                basename = os.path.splitext(os.path.basename(x))[0]

                # Default to the concept-level prompt; a per-image caption
                # replaces it only when captions are enabled and the file
                # exists with at least one line.
                instance_prompt_image = instance_prompt
                if use_caption and caption_dir is not None:
                    caption_path = os.path.join(caption_dir, f'{basename}.txt')
                    if os.path.exists(caption_path):
                        with open(caption_path, 'r') as fr:
                            first_line = fr.readline()
                        # readline() returns '' only for an empty file, which
                        # previously crashed with IndexError on readlines()[0].
                        if first_line:
                            instance_prompt_image = self.process_text(first_line, replace_mapping)

                if use_mask and mask_dir is not None:
                    # The mask file is not checked for existence here; it is
                    # opened lazily in __getitem__.
                    mask_path = os.path.join(mask_dir, f'{basename}.png')
                else:
                    mask_path = None

                self.instance_images_path.append((x, instance_prompt_image, mask_path))

        random.shuffle(self.instance_images_path)
        self.num_instance_images = len(self.instance_images_path)

        self.instance_transform = PairCompose([
            build_transform(transform_opt) for transform_opt in opt['instance_transform']
        ])

    def process_text(self, instance_prompt, replace_mapping):
        """Apply the replacements, strip the ends and collapse runs of spaces.

        Args:
            instance_prompt: raw prompt text.
            replace_mapping: ``{old: new}`` substrings, applied in the
                mapping's iteration order.

        Returns:
            The normalized prompt string.
        """
        for k, v in replace_mapping.items():
            instance_prompt = instance_prompt.replace(k, v)
        instance_prompt = instance_prompt.strip()
        # Replacements may leave several consecutive spaces behind.
        instance_prompt = re.sub(' +', ' ', instance_prompt)
        return instance_prompt

    def __len__(self):
        """Return the number of images times ``opt['dataset_enlarge_ratio']``."""
        return self.num_instance_images * self.opt['dataset_enlarge_ratio']

    def __getitem__(self, index):
        """Load and transform one sample.

        Args:
            index: sample index; wrapped modulo the number of images so that
                indices from the enlarged range map back onto real files.

        Returns:
            A dict with ``'images'``, ``'img_masks'`` and ``'prompts'``, plus
            ``'masks'`` when the transform output contains a ``'mask'`` entry.

        Raises:
            NotImplementedError: if the transform output has no
                ``'img_mask'`` entry.
        """
        example = {}

        instance_image, instance_prompt, instance_mask = self.instance_images_path[
            index % self.num_instance_images
        ]
        instance_image = Image.open(instance_image).convert('RGB')

        extra_args = {'prompts': instance_prompt}
        if instance_mask is not None:
            instance_mask = Image.open(instance_mask).convert('L')
            extra_args.update({'mask': instance_mask})

        instance_image, extra_args = self.instance_transform(instance_image, **extra_args)
        example['images'] = instance_image

        # NOTE(review): unsqueeze(0) implies the transforms return tensors for
        # the masks - confirm against the configured transform pipeline.
        if 'mask' in extra_args:
            example['masks'] = extra_args['mask'].unsqueeze(0)

        if 'img_mask' not in extra_args:
            raise NotImplementedError(
                "instance_transform did not produce the required 'img_mask' entry")
        example['img_masks'] = extra_args['img_mask'].unsqueeze(0)

        example['prompts'] = extra_args['prompts']
        return example