File size: 1,820 Bytes
afc2161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import h5py
import torch
from torch.utils.data import Dataset

from ellipse_rcnn.utils.types import TargetDict, ImageTargetTuple
from ellipse_rcnn.utils.conics import bbox_ellipse


class CraterEllipseDataset(Dataset):
    """
    Dataset for crater ellipse detection. Mostly meant as an example in combination with
    https://github.com/wdoppenberg/crater-detection.
    """

    def __init__(self, file_path: str, group: str) -> None:
        self.file_path = file_path
        self.group = group

    def __getitem__(self, idx: torch.Tensor) -> ImageTargetTuple:
        with h5py.File(self.file_path, "r") as dataset:
            image = torch.tensor(dataset[self.group]["images"][idx])

            # The number of instances can vary, hence it was decided to use a separate array with the indices of the
            # instances.
            start_idx = dataset[self.group]["craters/crater_list_idx"][idx]
            end_idx = dataset[self.group]["craters/crater_list_idx"][idx + 1]
            ellipse_matrices = torch.tensor(
                dataset[self.group]["craters/A_craters"][start_idx:end_idx]
            )

        boxes = bbox_ellipse(ellipse_matrices)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        num_objs = len(boxes)

        labels = torch.ones((num_objs,), dtype=torch.int64)
        image_id = torch.tensor([idx])

        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = TargetDict(
            boxes=boxes,
            labels=labels,
            image_id=image_id,
            area=area,
            iscrowd=iscrowd,
            ellipse_matrices=ellipse_matrices,
        )

        return image, target

    def __len__(self) -> int:
        with h5py.File(self.file_path, "r") as f:
            return len(f[self.group]["images"])