# GeoCalib: siclib/models/cache_loader.py
import string
from typing import Optional

import h5py
import torch

from siclib.datasets.base_dataset import collate
from siclib.models.base_model import BaseModel
from siclib.settings import DATA_PATH
from siclib.utils.tensor import batch_to_device
# flake8: noqa
# mypy: ignore-errors
def pad_line_features(pred, seq_l: Optional[int] = None):
    """Pad line features in `pred` to a fixed sequence length `seq_l`.

    Placeholder referenced by the `padding_fn` config option of CacheLoader;
    not implemented in this file.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
def recursive_load(grp, pkeys):
    """Load the keys `pkeys` from the h5py group `grp` into a nested dict.

    Datasets are materialized as torch tensors (copied into memory via
    `__array__`); subgroups are loaded recursively with all of their keys.

    Args:
        grp: open `h5py.Group` (or `h5py.File`) to read from.
        pkeys: iterable of key names to load from `grp`.

    Returns:
        dict mapping each key to a tensor or a nested dict of tensors.
    """
    # Bug fix: the recursion previously passed the PARENT group's key list
    # (`list(grp.keys())`) when descending into the subgroup `grp[k]`, which
    # fails (or reads wrong keys) whenever a subgroup's keys differ from the
    # parent's. Enumerate the subgroup's own keys instead.
    return {
        k: (
            torch.from_numpy(grp[k].__array__())
            if isinstance(grp[k], h5py.Dataset)
            else recursive_load(grp[k], list(grp[k].keys()))
        )
        for k in pkeys
    }
class CacheLoader(BaseModel):
    """Pseudo-model that loads cached predictions from HDF5 export files.

    For each sample (identified by ``data["name"]``) the loader resolves an
    HDF5 file path from a format string, reads the requested datasets into
    tensors, optionally casts floating-point tensors to a target dtype,
    rescales keys matching configured prefixes, applies an optional padding
    function, and finally collates the per-sample dicts into a batch.
    """

    default_conf = {
        "path": "???",  # can be a format string like exports/{scene}/
        "data_keys": None,  # load all keys
        "device": None,  # load to same device as data
        "trainable": False,
        "add_data_path": True,
        "collate": True,
        "scale": ["keypoints"],
        "padding_fn": None,
        "padding_length": None,  # required for batching!
        "numeric_type": "float32",  # [None, "float16", "float32", "float64"]
    }

    required_data_keys = ["name"]  # we need an identifier

    def _init(self, conf):
        """Resolve the padding function and the target dtype from the config."""
        self.hfiles = {}
        self.padding_fn = conf.padding_fn
        if self.padding_fn is not None:
            # NOTE(review): eval() on a config-provided string executes
            # arbitrary code — only safe with trusted configs. A
            # name-to-function registry would be a safer lookup mechanism.
            self.padding_fn = eval(self.padding_fn)
        # Map the configured numeric type to a torch dtype (None = keep as-is).
        self.numeric_dtype = {
            None: None,
            "float16": torch.float16,
            "float32": torch.float32,
            "float64": torch.float64,
        }[conf.numeric_type]

    def _forward(self, data):  # sourcery skip: low-code-quality
        """Load, cast, rescale, pad, and (optionally) collate cached predictions."""
        preds = []
        device = self.conf.device
        if not device:
            # Infer the device from the input tensors; fall back to CPU.
            if devices := {v.device for v in data.values() if isinstance(v, torch.Tensor)}:
                assert len(devices) == 1
                device = devices.pop()
            else:
                device = "cpu"
        # Field names referenced by the path format string, e.g. "{scene}".
        var_names = [x[1] for x in string.Formatter().parse(self.conf.path) if x[1]]
        for i, name in enumerate(data["name"]):
            fpath = self.conf.path.format(**{k: data[k][i] for k in var_names})
            if self.conf.add_data_path:
                fpath = DATA_PATH / fpath
            # Use a context manager so the file is closed even if loading
            # raises; the previous explicit close() leaked the handle on any
            # exception between open and close.
            with h5py.File(str(fpath), "r") as hfile:
                grp = hfile[name]
                pkeys = self.conf.data_keys if self.conf.data_keys is not None else grp.keys()
                # recursive_load materializes tensors in memory, so the file
                # can be closed right after this call.
                pred = recursive_load(grp, pkeys)
            if self.numeric_dtype is not None:
                # Cast only floating-point tensors; leave ints/bools untouched.
                pred = {
                    k: (
                        v
                        if not isinstance(v, torch.Tensor) or not torch.is_floating_point(v)
                        else v.to(dtype=self.numeric_dtype)
                    )
                    for k, v in pred.items()
                }
            pred = batch_to_device(pred, device)
            # Rescale keys matching the configured prefixes — presumably
            # values cached at a different image resolution (TODO confirm).
            # A key suffix like "0"/"1" selects the per-view scale factors.
            for k in pred:
                for pattern in self.conf.scale:
                    if k.startswith(pattern):
                        view_idx = k.replace(pattern, "")
                        scales = (
                            data["scales"]
                            if len(view_idx) == 0
                            else data[f"view{view_idx}"]["scales"]
                        )
                        pred[k] = pred[k] * scales[i]
            # use this function to fix number of keypoints etc.
            if self.padding_fn is not None:
                pred = self.padding_fn(pred, self.conf.padding_length)
            preds.append(pred)
        if self.conf.collate:
            return batch_to_device(collate(preds), device)
        # Without collation only single-sample batches are supported.
        assert len(preds) == 1
        return batch_to_device(preds[0], device)

    def loss(self, pred, data):
        """Loss is undefined for a cache loader; it produces no trainable output."""
        raise NotImplementedError