# -*- coding: utf-8 -*-
# @Author  : xuelun

import os
import cv2
import torch

from os.path import join
from torch.utils.data import Dataset


def collate_fn(batch):
    """Collate a batch after dropping any samples that are None."""
    batch = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(batch)


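class WALKDataset(Dataset):
    """Pairs of frames, read from cached PNGs or, failing that, from `vs`."""

    def __init__(self, data_root, vs, ids, checkpoint, opt):
        super().__init__()

        # `vs` is an indexable frame source: vs[i] yields frame i.
        self.vs = vs
        # Skip the first `checkpoint` pairs (presumably to resume an earlier run).
        self.ids = ids[checkpoint:]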

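        # The scene name may carry stray whitespace: use the stripped name as the
        # image directory and rename a directory created under the raw name.
        old_image_root = join(data_root, 'image_1080p', opt.scene_name)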
        new_image_root = join(data_root, 'image_1080p', opt.scene_name.strip())
        if not os.path.exists(new_image_root):
            if os.path.exists(old_image_root):
                os.rename(old_image_root, new_image_root)
            else:
                os.makedirs(new_image_root, exist_ok=True)
        self.image_root = new_image_root

    def __len__(self):
        return len(self.ids)

    def _load_frame(self, frame_idx):
        """Return (image, path, is_good) for one frame index."""
        img_path = join(self.image_root, '{}.png'.format(frame_idx))
        # cv2.imread returns None for unreadable files, so check the result too.
        rgb = cv2.imread(img_path) if os.path.exists(img_path) else None
        if rgb is None:
            # No usable cached PNG: fall back to the frame source.
            return self.vs[frame_idx], img_path, False
        return rgb, img_path, True

    def __getitem__(self, idx):
        idx0, idx1 = self.ids[idx]

        # get image
        rgb0, img_path0, rgb0_is_good = self._load_frame(idx0)
        rgb1, img_path1, rgb1_is_good = self._load_frame(idx1)

        return {'idx': idx, 'idx0': idx0, 'idx1': idx1, 'rgb0': rgb0, 'rgb1': rgb1,
                'img_path0': img_path0, 'img_path1': img_path1,
                'rgb0_is_good': rgb0_is_good, 'rgb1_is_good': rgb1_is_good}