Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import numpy | |
import json | |
import zipfile | |
import torch | |
from PIL import Image | |
# from transformers import CLIPImageProcessor | |
from torch.utils.data import Dataset | |
import io | |
from omegaconf import OmegaConf | |
import numpy as np | |
# from torchvision import transforms | |
# from einops import rearrange | |
# import random | |
# import os | |
# from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDIMScheduler | |
# import time | |
# import io | |
# import array | |
# import numpy as np | |
# | |
# from training.triplane import TriPlaneGenerator | |
def to_rgb_image(maybe_rgba: Image.Image):
    """Return *maybe_rgba* as an RGB image.

    RGB images are returned unchanged (the same object). RGBA images are
    composited over a solid mid-gray (127, 127, 127) background, using their
    alpha channel as the paste mask.

    Raises:
        ValueError: if the image mode is neither 'RGB' nor 'RGBA'.
    """
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    if maybe_rgba.mode == 'RGBA':
        # Constant fill rather than a random draw: the background is always
        # 127-gray, so there is no reason to touch numpy's global RNG state.
        background = Image.new('RGB', maybe_rgba.size, (127, 127, 127))
        background.paste(maybe_rgba, mask=maybe_rgba.getchannel('A'))
        return background
    raise ValueError("Unsupported image type.", maybe_rgba.mode)
# image(contain style),z,pose,text | |
class TriplaneDataset(Dataset):
    """Dataset that reads per-sample tensors from per-model zip archives.

    A JSON index maps each sample name (presumably an image file name such as
    "seed0035.png" -- confirm against the index generator) to a record with at
    least the keys ``'model_name'``, ``'z_dir'`` and ``'vert_dir'``.
    ``model_name`` selects which ``<model_name>.zip`` archive to read from;
    ``z_dir`` and ``vert_dir`` are member paths inside that archive, each
    holding a ``torch.save``-serialized object.
    """

    # Total number of samples tried (the requested one plus random
    # replacements) before __getitem__ gives up.
    _MAX_LOAD_ATTEMPTS = 20

    def __init__(self, json_file, data_base_dir, model_names):
        """Load the sample index and open one zip archive per configured model.

        Args:
            json_file: path to the JSON index described in the class docstring.
            data_base_dir: directory containing the ``<model_name>.zip`` archives.
            model_names: path to an OmegaConf-loadable config; the keys of its
                ``gan_models`` mapping are the model names to open.
        """
        super().__init__()
        with open(json_file, encoding='utf-8') as fp:
            self.dict_data_image = json.load(fp)
        self.data_base_dir = data_base_dir
        self.data_list = list(self.dict_data_image.keys())
        # NOTE(review): archives are opened once here and reused for every read.
        # If this dataset is used with a multi-process DataLoader, confirm the
        # inherited handles are safe to share across worker processes.
        self.zip_file_dict = {}
        config_gan_model = OmegaConf.load(model_names)
        for model_name in config_gan_model['gan_models'].keys():
            zipfile_path = os.path.join(self.data_base_dir, model_name + '.zip')
            self.zip_file_dict[model_name] = zipfile.ZipFile(zipfile_path)

    @staticmethod
    def _load_zip_member(archive, member):
        """Deserialize the ``torch.save`` payload stored at *member* in *archive*."""
        # Read the member fully into memory first so torch.load gets a cheap,
        # seekable stream instead of a compressed zip entry.
        with archive.open(member, 'r') as fp, io.BytesIO(fp.read()) as buffer:
            # NOTE: torch.load unpickles the payload; only use trusted archives.
            return torch.load(buffer)

    def getdata(self, idx):
        """Load sample *idx* with no retry logic.

        Returns:
            dict with ``'data_z'`` and ``'data_vert'`` (the objects stored at the
            record's ``'z_dir'`` / ``'vert_dir'`` members) and
            ``'data_model_name'`` (the record's ``'model_name'``).

        Raises:
            Whatever the lookup or deserialization raises, e.g. ``KeyError`` for
            an unknown model name or a member missing from the archive.
        """
        data_name = self.data_list[idx]
        record = self.dict_data_image[data_name]
        data_model_name = record['model_name']
        archive = self.zip_file_dict[data_model_name]
        data_z = self._load_zip_member(archive, record['z_dir'])
        data_vert = self._load_zip_member(archive, record['vert_dir'])
        return {
            "data_z": data_z,
            "data_vert": data_vert,
            "data_model_name": data_model_name,
        }

    def __getitem__(self, idx):
        """Return sample *idx*, substituting random samples when loading fails.

        Up to ``_MAX_LOAD_ATTEMPTS`` samples are tried. Each failure is printed,
        and the next index is drawn uniformly at random with numpy's global RNG.

        Raises:
            RuntimeError: if every attempt fails; the last underlying error is
                attached as ``__cause__``.
        """
        last_error = None
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            try:
                return self.getdata(idx)
            except Exception as e:  # deliberately broad: skip any unreadable sample
                print(f"Error details: {str(e)}")
                last_error = e
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data.') from last_error

    def __len__(self):
        """Number of samples listed in the JSON index."""
        return len(self.data_list)
# for zip files | |