import string

import h5py
import torch

from siclib.datasets.base_dataset import collate
from siclib.models.base_model import BaseModel
from siclib.settings import DATA_PATH
from siclib.utils.tensor import batch_to_device


def pad_line_features(pred, seq_l: int = None):
    raise NotImplementedError


def recursive_load(grp, pkeys):
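    """Recursively load an HDF5 group into a nested dict, converting datasets to torch tensors."""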
    return {
        k: (
            torch.from_numpy(grp[k].__array__())
            if isinstance(grp[k], h5py.Dataset)
            else recursive_load(grp[k], list(grp[k].keys()))
        )
        for k in pkeys
    }


class CacheLoader(BaseModel):
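    """Load precomputed per-sample predictions from HDF5 cache files instead of computing them."""
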
    default_conf = {
        "path": "???",
        "data_keys": None,
        "device": None,
        "trainable": False,
        "add_data_path": True,
        "collate": True,
        "scale": ["keypoints"],
        "padding_fn": None,
        "padding_length": None,
        "numeric_type": "float32",
    }

    required_data_keys = ["name"]

    def _init(self, conf):
        self.hfiles = {}
        self.padding_fn = conf.padding_fn
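        # padding_fn is configured as a string and resolved to a callable via eval.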
|
if self.padding_fn is not None: |
|
self.padding_fn = eval(self.padding_fn) |
|
self.numeric_dtype = { |
|
None: None, |
|
"float16": torch.float16, |
|
"float32": torch.float32, |
|
"float64": torch.float64, |
|
}[conf.numeric_type] |
|
|
|
    def _forward(self, data):
        preds = []
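        # Use the configured device, or infer it from the tensors in the input batch.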
        device = self.conf.device
        if not device:
            if devices := {v.device for v in data.values() if isinstance(v, torch.Tensor)}:
                assert len(devices) == 1
                device = devices.pop()
            else:
                device = "cpu"

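        # Resolve the placeholders of the cache path template from the matching per-sample batch entries.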
        var_names = [x[1] for x in string.Formatter().parse(self.conf.path) if x[1]]
        for i, name in enumerate(data["name"]):
            fpath = self.conf.path.format(**{k: data[k][i] for k in var_names})
            if self.conf.add_data_path:
                fpath = DATA_PATH / fpath
            hfile = h5py.File(str(fpath), "r")
            grp = hfile[name]
            pkeys = self.conf.data_keys if self.conf.data_keys is not None else grp.keys()
            pred = recursive_load(grp, pkeys)
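            # Cast floating-point tensors to the configured numeric type.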
            if self.numeric_dtype is not None:
                pred = {
                    k: (
                        v
                        if not isinstance(v, torch.Tensor) or not torch.is_floating_point(v)
                        else v.to(dtype=self.numeric_dtype)
                    )
                    for k, v in pred.items()
                }
            pred = batch_to_device(pred, device)
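            # Rescale cached entries whose key matches a pattern in conf.scale (e.g. keypoints)
            # by the per-sample "scales" factor of the corresponding view.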
            for k, v in pred.items():
                for pattern in self.conf.scale:
                    if k.startswith(pattern):
                        view_idx = k.replace(pattern, "")
                        scales = (
                            data["scales"]
                            if len(view_idx) == 0
                            else data[f"view{view_idx}"]["scales"]
                        )
                        pred[k] = pred[k] * scales[i]

            if self.padding_fn is not None:
                pred = self.padding_fn(pred, self.conf.padding_length)
            preds.append(pred)
            hfile.close()
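        # Collate the per-sample predictions into a batch, or return the single prediction as-is.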
        if self.conf.collate:
            return batch_to_device(collate(preds), device)
        assert len(preds) == 1
        return batch_to_device(preds[0], device)

    def loss(self, pred, data):
        raise NotImplementedError