File size: 6,885 Bytes
e4bf056
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import os
import cv2
import json
import itertools
import numpy as np
import os.path as osp
from collections import deque

from dust3r.utils.image import imread_cv2
from .base_many_view_dataset import BaseManyViewDataset


class Co3d(BaseManyViewDataset):
    """Many-view CO3D-v2 dataset.

    Each item is a list of views sampled from one scene instance; a view
    bundles the RGB image, metric depth map, camera pose and intrinsics.
    Frame indices are drawn either from precomputed index combinations
    (``use_comb``) or from the base class's ``sample_frames`` sampler.
    """

    def __init__(self, mask_bg=True, use_comb=True,
                 scene_class=None, scene_id=None,
                 num_seq=100, num_frames=5, 
                 min_thresh=10, max_thresh=100,
                 full_video=False, lb=0, ub=30,
                 kf_every=1, *args, ROOT, **kwargs):
        # ROOT: dataset root containing <class>/<instance>/{images,depths,masks}
        # plus the selected_seqs_<split>.json index file.
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)

        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg  # True / False / 'rand' (decided per item)

        self.num_seq = num_seq          # sequences per scene when combinations are unused
        self.num_frames = num_frames    # views per item
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh

        self.full_video = full_video
        self.kf_every = kf_every
        self.use_comb = use_comb

        self.scenes, self.scene_list = self.load_scene(scene_class, scene_id)

        self.combinations, self.num_seq = self.get_combinations(use_comb, lb, ub)
        # Per-scene, per-resolution flags for frames whose depth maps turn out
        # to be empty after cropping/resizing (filled lazily in _get_views).
        self.invalidate = {scene: {} for scene in self.scene_list}

    def get_combinations(self, use_comb, lb, ub):
        """Build the list of frame-index combinations to sample from.

        Keeps only combinations whose consecutive index gaps lie in (lb, ub]
        and are multiples of 5. Returns (combinations, num_seq); combinations
        is None when sampling is delegated to ``sample_frames``.
        """
        if use_comb and not self.full_video:
            print('Using combinations')
            combinations = list(itertools.combinations(range(100), self.num_frames))
            combinations = [combo for combo in combinations
                            if all(lb < abs(x - y) <= ub and abs(x - y) % 5 == 0
                                   for x, y in zip(combo, combo[1:]))]
            num_seq = len(combinations)
            print('Number of sequences:', num_seq)
        else:
            combinations = None
            num_seq = self.num_seq

        return combinations, num_seq

    def load_scene(self, scene_class=None, scene_id=None):
        """Load the per-split scene index, optionally filtered.

        Returns (scenes, scene_list) where scenes maps (class, instance) to
        the list of selected image indices and scene_list is its key list.
        """
        print('Loading scenes')
        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
            scenes = json.load(f)

            if scene_class is not None:
                # Keep a single object class (e.g. 'apple').
                scenes = {k: v for k, v in scenes.items() if k == scene_class}
            else:
                scenes = {k: v for k, v in scenes.items() if len(v) > 0}  # k is class (apple), v is corresponding list

            if scene_id is not None:
                # Keep a single instance (e.g. '110_13051_23361').
                scenes = {(k, k2): v2 for k, v in scenes.items() for k2, v2 in v.items() if k2 == scene_id}
            else:
                scenes = {(k, k2): v2 for k, v in scenes.items()
                            for k2, v2 in v.items()}  # k is class (apple), k2 is instance (110_13051_23361), v2 is list of image idx
        scene_list = list(scenes.keys())

        return scenes, scene_list

    def __len__(self):
        # One item per (scene, sequence) pair.
        return len(self.scene_list) * self.num_seq

    def _get_views(self, idx, resolution, rng, attempts=0):
        """Assemble one multi-view item.

        ``idx`` indexes a (scene, sequence) pair; ``attempts`` counts how
        often this item has been re-sampled because the scene's depth scale
        was implausible or all of its frames were invalid.
        """
        obj, instance = self.scene_list[idx // self.num_seq]
        image_pool = self.scenes[obj, instance]

        if self.use_comb and not self.full_video:
            frame_idx = self.combinations[idx % len(self.combinations)]

            # Jitter each index by up to +/-4 frames, clamped to the pool.
            last = len(image_pool) - 1
            imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in frame_idx]
        else:
            img_idx = range(0, len(image_pool))
            imgs_idxs = self.sample_frames(img_idx, rng)

        if resolution not in self.invalidate[obj, instance]:  # flag invalid images
            self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]

        mask_bg = (self.mask_bg is True) or (self.mask_bg == 'rand' and rng.choice(2))

        imgs_idxs = deque(imgs_idxs)

        # Track the per-frame 'maximum_depth' metadata to reject items whose
        # depth scale varies implausibly across the sampled views.
        max_depth_min = 1e8
        max_depth_max = 0.
        max_depth_first = None

        views = []

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()

            if self.invalidate[obj, instance][resolution][im_idx]:
                # Search for a valid image, scanning in a random direction.
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, len(image_pool)):
                    tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
                    if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
                        im_idx = tentative_im_idx
                        break
                else:
                    # Every frame of this scene is invalid at this resolution;
                    # discard partial results and resample a different item
                    # below (previously this case looped forever).
                    views = []
                    break

            view_idx = image_pool[im_idx]

            impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06d}.jpg')

            # Load camera params; swap only the file extension (a plain
            # str.replace('jpg', 'npz') would also rewrite any 'jpg'
            # occurring elsewhere in the path).
            input_metadata = np.load(osp.splitext(impath)[0] + '.npz')
            camera_pose = input_metadata['camera_pose'].astype(np.float32)
            intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED)
            # Depth PNGs store 16-bit values normalized by the scene's maximum depth.
            depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])

            if mask_bg:
                # Load the object mask and zero out background depth.
                maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06d}.png')
                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
                maskmap = (maskmap / 255.0) > 0.1

                # update the depthmap with mask
                depthmap *= maskmap

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0:
                # No valid depth left after masking/cropping: flag this frame
                # and retry (the retry path above will pick a neighbor).
                self.invalidate[obj, instance][resolution][im_idx] = True
                imgs_idxs.appendleft(im_idx)
                continue

            seq_max_depth = input_metadata['maximum_depth']
            if seq_max_depth > max_depth_max:
                max_depth_max = seq_max_depth
            if seq_max_depth < max_depth_min:
                max_depth_min = seq_max_depth
            if max_depth_first is None:
                max_depth_first = seq_max_depth

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset='Co3d_v2',
                label=osp.join(obj, instance),
                instance=osp.split(impath)[1]
            ))

        # Reject the item when nothing valid was collected (max_depth_first is
        # None, which previously crashed the division below) or the depth
        # scale varies too much across views / relative to the first view.
        depth_ok = (max_depth_first is not None
                    and max_depth_max / max_depth_min <= 100.
                    and max_depth_max / max_depth_first <= 10.)
        if not depth_ok:
            if attempts >= 100:
                # Bound the retries instead of recursing forever (`attempts`
                # was previously accepted but never used).
                raise RuntimeError(
                    f'Co3d: no valid sequence found after {attempts} attempts '
                    f'(last scene: {obj}/{instance})')
            # rng.integers has an exclusive upper bound, so use len(self) to
            # keep the last index reachable (was len(self) - 1, an off-by-one).
            new_idx = rng.integers(0, len(self))
            return self._get_views(new_idx, resolution, rng, attempts + 1)

        return views