lehduong committed
Commit c2900a3 · verified · 1 Parent(s): 8e65b97

Delete dataset/multitask/multiview.py with huggingface_hub
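For reference, a file deletion like this can be produced programmatically with the huggingface_hub client. A minimal sketch, assuming a standard authenticated setup; the repo_id and repo_type below are placeholders, not taken from this commit:

    from huggingface_hub import HfApi

    api = HfApi()  # picks up the token from `huggingface-cli login` by default

    # repo_id is a placeholder; point it at the actual repository.
    api.delete_file(
        path_in_repo="dataset/multitask/multiview.py",
        repo_id="your-username/your-repo",
        repo_type="model",  # assumption: use "dataset" if the target is a dataset repo
        commit_message="Delete dataset/multitask/multiview.py with huggingface_hub",
    )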

Files changed (1)
  1. dataset/multitask/multiview.py +0 -277
dataset/multitask/multiview.py DELETED
@@ -1,277 +0,0 @@
- import os
- import json
- import random
- from PIL import Image
- import torch
- from typing import List, Tuple, Union
- from torch.utils.data import Dataset
- from torchvision import transforms
- import torchvision.transforms as T
- from onediffusion.dataset.utils import *
- import glob
-
- from onediffusion.dataset.raydiff_utils import cameras_to_rays, first_camera_transform, normalize_cameras
- from onediffusion.dataset.transforms import CenterCropResizeImage
- from pytorch3d.renderer import PerspectiveCameras
-
- import numpy as np
-
- def _cameras_from_opencv_projection(
-     R: torch.Tensor,
-     tvec: torch.Tensor,
-     camera_matrix: torch.Tensor,
-     image_size: torch.Tensor,
-     do_normalize_cameras,
-     normalize_scale,
- ) -> PerspectiveCameras:
-     focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
-     principal_point = camera_matrix[:, :2, 2]
-
-     # Retype the image_size correctly and flip to width, height.
-     image_size_wh = image_size.to(R).flip(dims=(1,))
-
-     # Screen to NDC conversion:
-     # For non-square images, we scale the points such that the smallest side
-     # has range [-1, 1] and the largest side has range [-u, u], with u > 1.
-     # This convention is consistent with the PyTorch3D renderer, as well as
-     # the transformation function `get_ndc_to_screen_transform`.
-     scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0
-     scale = scale.expand(-1, 2)
-     c0 = image_size_wh / 2.0
-
-     # Get the PyTorch3D focal length and principal point.
-     focal_pytorch3d = focal_length / scale
-     p0_pytorch3d = -(principal_point - c0) / scale
-
-     # For R, T we flip x, y axes (OpenCV screen space has an opposite
-     # orientation of screen axes).
-     # We also transpose R (OpenCV multiplies points from the opposite, left side).
-     R_pytorch3d = R.clone().permute(0, 2, 1)
-     T_pytorch3d = tvec.clone()
-     R_pytorch3d[:, :, :2] *= -1
-     T_pytorch3d[:, :2] *= -1
-
-     cams = PerspectiveCameras(
-         R=R_pytorch3d,
-         T=T_pytorch3d,
-         focal_length=focal_pytorch3d,
-         principal_point=p0_pytorch3d,
-         image_size=image_size,
-         device=R.device,
-     )
-
-     if do_normalize_cameras:
-         cams, _ = normalize_cameras(cams, scale=normalize_scale)
-
-     cams = first_camera_transform(cams, rotation_only=False)
-     return cams
-
- def calculate_rays(Ks, sizes, Rs, Ts, target_size, use_plucker=True, do_normalize_cameras=False, normalize_scale=1.0):
-     cameras = _cameras_from_opencv_projection(
-         R=Rs,
-         tvec=Ts,
-         camera_matrix=Ks,
-         image_size=sizes,
-         do_normalize_cameras=do_normalize_cameras,
-         normalize_scale=normalize_scale
-     )
-
-     rays_embedding = cameras_to_rays(
-         cameras=cameras,
-         num_patches_x=target_size,
-         num_patches_y=target_size,
-         crop_parameters=None,
-         use_plucker=use_plucker
-     )
-
-     return rays_embedding.rays
-
- def convert_rgba_to_rgb_white_bg(image):
-     """Convert RGBA image to RGB with white background"""
-     if image.mode == 'RGBA':
-         # Create a white background
-         background = Image.new('RGBA', image.size, (255, 255, 255, 255))
-         # Composite the image onto the white background
-         return Image.alpha_composite(background, image).convert('RGB')
-     return image.convert('RGB')
-
- class MultiviewDataset(Dataset):
-     def __init__(
-         self,
-         scene_folders: str,
-         samples_per_set: Union[int, Tuple[int, int]],  # stored internally as samples_range
-         transform=None,
-         caption_keys: Union[str, List] = "caption",
-         multiscale=False,
-         aspect_ratio_type=ASPECT_RATIO_512,
-         c2w_scaling=1.7,
-         default_max_distance=1,  # default max distance over all cameras of a scene
-         do_normalize=True,  # whether to normalize the translation of c2w with max_distance
-         swap_xz=False,  # whether to swap the x and z axes of 3D scenes
-         valid_paths: str = "",
-         frame_sliding_windows: float = None  # limit all sampled frames to this window so camera poses don't differ too much
-     ):
-         if not isinstance(samples_per_set, tuple) and not isinstance(samples_per_set, list):
-             samples_per_set = (samples_per_set, samples_per_set)
-         self.samples_range = samples_per_set  # Tuple of (min_samples, max_samples)
-         self.transform = transform
-         self.caption_keys = caption_keys if isinstance(caption_keys, list) else [caption_keys]
-         self.aspect_ratio = aspect_ratio_type
-         self.scene_folders = sorted(glob.glob(scene_folders))
-         # filter out scene folders that do not have transforms.json
-         self.scene_folders = list(filter(lambda x: os.path.exists(os.path.join(x, "transforms.json")), self.scene_folders))
-
-         # if valid_paths.txt exists, only use paths in that file
-         if os.path.exists(valid_paths):
-             with open(valid_paths, 'r') as f:
-                 valid_scene_folders = f.read().splitlines()
-             self.scene_folders = sorted(valid_scene_folders)
-
-         self.c2w_scaling = c2w_scaling
-         self.do_normalize = do_normalize
-         self.default_max_distance = default_max_distance
-         self.swap_xz = swap_xz
-         self.frame_sliding_windows = frame_sliding_windows
-
-         if multiscale:
-             assert self.aspect_ratio in [ASPECT_RATIO_512, ASPECT_RATIO_1024, ASPECT_RATIO_2048, ASPECT_RATIO_2880]
-             if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
-                 self.interpolate_model = T.InterpolationMode.LANCZOS
-             self.ratio_index = {}
-             self.ratio_nums = {}
-             for k, v in self.aspect_ratio.items():
-                 self.ratio_index[float(k)] = []  # used for self.getitem
-                 self.ratio_nums[float(k)] = 0  # used for batch-sampler
-
-     def __len__(self):
-         return len(self.scene_folders)
-
-     def __getitem__(self, idx):
-         try:
-             scene_path = self.scene_folders[idx]
-
-             if os.path.exists(os.path.join(scene_path, "images")):
-                 image_folder = os.path.join(scene_path, "images")
-                 downscale_factor = 1
-             elif os.path.exists(os.path.join(scene_path, "images_4")):
-                 image_folder = os.path.join(scene_path, "images_4")
-                 downscale_factor = 1 / 4
-             elif os.path.exists(os.path.join(scene_path, "images_8")):
-                 image_folder = os.path.join(scene_path, "images_8")
-                 downscale_factor = 1 / 8
-             else:
-                 raise NotImplementedError
-
-             json_path = os.path.join(scene_path, "transforms.json")
-             caption_path = os.path.join(scene_path, "caption.json")
-             image_files = os.listdir(image_folder)
-
-             with open(json_path, 'r') as f:
-                 json_data = json.load(f)
-             height, width = json_data['h'], json_data['w']
-
-             dh, dw = int(height * downscale_factor), int(width * downscale_factor)
-             fl_x, fl_y = json_data['fl_x'] * downscale_factor, json_data['fl_y'] * downscale_factor
-             cx = dw // 2
-             cy = dh // 2
-
-             frame_list = json_data['frames']
-
-             # Randomly select the number of samples
-             samples_per_set = random.randint(self.samples_range[0], self.samples_range[1])
-
-             # uniformly for all scenes
-             if self.frame_sliding_windows is None:
-                 selected_indices = random.sample(range(len(frame_list)), min(samples_per_set, len(frame_list)))
-             # limit the multiview set to a sliding window (to avoid catastrophic differences in camera angles)
-             else:
-                 # Determine the starting index of the sliding window
-                 if len(frame_list) <= self.frame_sliding_windows:
-                     # If the frame list is smaller than or equal to the window, use the entire list
-                     window_start = 0
-                     window_end = len(frame_list)
-                 else:
-                     # Randomly select a starting point for the window
-                     window_start = random.randint(0, len(frame_list) - self.frame_sliding_windows)
-                     window_end = window_start + self.frame_sliding_windows
-
-                 # Get the indices within the sliding window
-                 window_indices = list(range(window_start, window_end))
-
-                 # Randomly sample indices from the window
-                 selected_indices = random.sample(window_indices, samples_per_set)
-
-             image_files = [os.path.basename(frame_list[i]['file_path']) for i in selected_indices]
-             image_paths = [os.path.join(image_folder, file) for file in image_files]
-
-             # Load images and convert RGBA to RGB with white background
-             images = [convert_rgba_to_rgb_white_bg(Image.open(image_path)) for image_path in image_paths]
-
-             if self.transform:
-                 images = [self.transform(image) for image in images]
-             else:
-                 closest_size, closest_ratio = self.aspect_ratio['1.0'], 1.0
-                 closest_size = tuple(map(int, closest_size))
-                 transform = T.Compose([
-                     T.ToTensor(),
-                     CenterCropResizeImage(closest_size),
-                     T.Normalize([.5], [.5]),
-                 ])
-                 images = [transform(image) for image in images]
-             images = torch.stack(images)
-
-             c2ws = [frame_list[i]['transform_matrix'] for i in selected_indices]
-             c2ws = torch.tensor(c2ws).reshape(-1, 4, 4)
-             # max_distance = json_data.get('max_distance', self.default_max_distance)
-             # if 'max_distance' not in json_data.keys():
-             #     print(f"not found `max_distance` in json path: {json_path}")
-
-             if self.swap_xz:
-                 swap_xz = torch.tensor([[[0, 0, 1., 0],
-                                          [0, 1., 0, 0],
-                                          [-1., 0, 0, 0],
-                                          [0, 0, 0, 1.]]])
-                 c2ws = swap_xz @ c2ws
-
-             # OpenGL to OpenCV convention
-             c2ws[:, 0:3, 1:3] *= -1
-             c2ws = c2ws[:, [1, 0, 2, 3], :]
-             c2ws[:, 2, :] *= -1
-
-             w2cs = torch.inverse(c2ws)
-             K = torch.tensor([[[fl_x, 0, cx], [0, fl_y, cy], [0, 0, 1]]]).repeat(len(c2ws), 1, 1)
-             Rs = w2cs[:, :3, :3]
-             Ts = w2cs[:, :3, 3]
-             sizes = torch.tensor([[dh, dw]]).repeat(len(c2ws), 1)
-
-             # get ray embedding and pad the last dimension to 16 (num channels of VAE)
-             # rays_od = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, use_plucker=False, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
-             rays = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
-             rays = rays.reshape(samples_per_set, closest_size[0] // 8, closest_size[1] // 8, 6)
-             # padding = (0, 10)  # pad the last dimension to 16
-             # rays = torch.nn.functional.pad(rays, padding, "constant", 0)
-             rays = torch.cat([rays, rays, rays[..., :4]], dim=-1) * 1.658
-
-             if os.path.exists(caption_path):
-                 with open(caption_path, 'r') as f:
-                     caption_key = random.choice(self.caption_keys)
-                     caption = json.load(f).get(caption_key, "")
-             else:
-                 caption = ""
-
-             caption = "[[multiview]] " + caption if caption else "[[multiview]]"
-
-             return {
-                 'pixel_values': images,
-                 'rays': rays,
-                 'aspect_ratio': closest_ratio,
-                 'caption': caption,
-                 'height': dh,
-                 'width': dw,
-                 # 'origins': rays_od[..., :3],
-                 # 'dirs': rays_od[..., 3:6]
-             }
-         except Exception as e:
-             return self.__getitem__(random.randint(0, len(self.scene_folders) - 1))
-
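For context, here is a minimal usage sketch of the deleted class, following the constructor signature shown in the diff above. The glob pattern, sampling range, and valid_paths file are illustrative placeholders, and ASPECT_RATIO_512 is assumed to come from onediffusion.dataset.utils, as the wildcard import in the file suggests:

    from onediffusion.dataset.utils import ASPECT_RATIO_512  # assumption based on the wildcard import above

    dataset = MultiviewDataset(
        scene_folders="data/scenes/*",        # placeholder glob over scene folders containing transforms.json
        samples_per_set=(4, 8),               # draw between 4 and 8 views per scene
        aspect_ratio_type=ASPECT_RATIO_512,
        do_normalize=True,
        valid_paths="data/valid_paths.txt",   # optional whitelist of scene folders (placeholder path)
    )
    sample = dataset[0]  # dict with 'pixel_values', 'rays', 'caption', 'aspect_ratio', 'height', 'width'

Note that the __getitem__ path effectively assumes transform=None: closest_size and closest_ratio are only defined in that branch, and the ray reshaping and returned aspect_ratio depend on them.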