|
import h5py |
|
import torch |
|
from torch.utils.data import Dataset |
|
|
|
from ellipse_rcnn.utils.types import TargetDict, ImageTargetTuple |
|
from ellipse_rcnn.utils.conics import bbox_ellipse |
|
|
|
|
|
class CraterEllipseDataset(Dataset):
    """
    Dataset for crater ellipse detection. Mostly meant as an example in combination with
    https://github.com/wdoppenberg/crater-detection.

    Samples are read from an HDF5 file. Inside the configured group this class reads:

    * ``images`` -- indexed by sample position.
    * ``craters/crater_list_idx`` -- offsets; entries ``i`` and ``i + 1`` bound the
      rows of ``craters/A_craters`` that belong to sample ``i``.
    * ``craters/A_craters`` -- the ellipse matrices of all samples, concatenated.

    The file is opened anew on every access instead of being kept open on the
    instance (presumably so instances stay cheap to copy into DataLoader worker
    processes -- confirm before changing).
    """

    def __init__(self, file_path: str, group: str) -> None:
        """
        Store where the data lives; no file access happens here.

        Args:
            file_path: Path of the HDF5 file to read from.
            group: Name of the group inside the file that holds the samples.
        """
        self.file_path = file_path
        self.group = group

    @staticmethod
    def _resolve_index(idx: "int | torch.Tensor", length: int) -> int:
        """
        Convert ``idx`` to a plain, non-negative position within ``[0, length)``.

        Args:
            idx: Anything ``int()`` accepts, e.g. a Python int or a one-element tensor.
            length: Number of samples available.

        Returns:
            The equivalent position counted from the start.

        Raises:
            IndexError: If ``idx`` falls outside ``[-length, length)``.
        """
        position = int(idx)
        if position < 0:
            # Sequence-style negative indexing: -1 is the last sample.
            position += length
        if not 0 <= position < length:
            raise IndexError(
                f"index {int(idx)} is out of range for dataset of length {length}"
            )
        return position

    def __getitem__(self, idx: "int | torch.Tensor") -> "ImageTargetTuple":
        """
        Load one image together with its detection target.

        Args:
            idx: Position of the sample. Negative values count from the end.

        Returns:
            A tuple ``(image, target)`` where ``target`` is a ``TargetDict`` holding
            ``boxes``, ``labels`` (all ones), ``image_id``, ``area``, ``iscrowd``
            (all zeros) and ``ellipse_matrices``.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        with h5py.File(self.file_path, "r") as dataset:
            group = dataset[self.group]
            images = group["images"]

            # Normalise before touching the offset table: with a raw negative idx,
            # ``idx + 1`` would no longer point at the entry following ``idx``
            # (e.g. -1 + 1 == 0) and, if the store accepts negative indices, the
            # crater slice below would silently be wrong.
            idx = self._resolve_index(idx, len(images))

            image = torch.tensor(images[idx])

            # The number of craters differs per image, so consecutive offsets
            # mark where this sample's rows start and end.
            offsets = group["craters/crater_list_idx"]
            start_idx = offsets[idx]
            end_idx = offsets[idx + 1]
            ellipse_matrices = torch.tensor(
                group["craters/A_craters"][start_idx:end_idx]
            )

        # Everything needed has been copied into tensors; the file is closed here.
        boxes = bbox_ellipse(ellipse_matrices)
        # NOTE(review): assumes boxes has 4 columns ordered (x0, y0, x1, y1) -- confirm
        # against bbox_ellipse.
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        num_objs = len(boxes)

        # Every object gets label 1 and is flagged as not-crowd.
        labels = torch.ones((num_objs,), dtype=torch.int64)
        image_id = torch.tensor([idx])
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = TargetDict(
            boxes=boxes,
            labels=labels,
            image_id=image_id,
            area=area,
            iscrowd=iscrowd,
            ellipse_matrices=ellipse_matrices,
        )

        return image, target

    def __len__(self) -> int:
        """Return the number of entries in the group's ``images`` dataset."""
        with h5py.File(self.file_path, "r") as f:
            return len(f[self.group]["images"])
|
|