|
""" |
|
Export the predictions of a model for a given dataloader (e.g. ImageFolder). |
|
Use as a standalone script with `python3 -m geocalib.scripts.export_predictions dir`

or call it from another script.
|
""" |
|
|
|
import logging |
|
from pathlib import Path |
|
|
|
import h5py |
|
import numpy as np |
|
import torch |
|
from tqdm import tqdm |
|
|
|
from siclib.utils.tensor import batch_to_device |
|
from siclib.utils.tools import get_device |
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(name=__name__)
|
|
|
|
|
@torch.no_grad()
def export_predictions(
    loader,
    model,
    output_file,
    as_half=False,
    keys="*",
    callback_fn=None,
    optional_keys=None,
    verbose=True,
):
    """Run a model over a dataloader and write per-sample predictions to HDF5.

    For every sample, an HDF5 group named ``data["name"][i]`` is created in
    ``output_file`` and each exported prediction is stored as a dataset in it.

    Args:
        loader: Batch iterable; must support ``len()``. Each batch must hold a
            ``"name"`` entry with one name per sample.
        model: Callable ``model(data)`` returning a dict of batched outputs;
            each value is indexed per sample and the item converted with
            ``.cpu().numpy()``.
        output_file: Target HDF5 path. Parent directories are created and an
            existing file is truncated (opened in ``"w"`` mode).
        as_half: Store ``float32`` arrays as ``float16``.
        keys: ``"*"`` to export every prediction, or a list/tuple of keys that
            must be present in every batch's predictions.
        callback_fn: Optional ``callback_fn(pred, data) -> dict``; its outputs
            are added to the predictions, with model outputs taking
            precedence on key collisions.
        optional_keys: Keys exported only if present. Ignored when
            ``keys == "*"``.
        verbose: Show a progress bar; if False, log a single line instead.

    Returns:
        ``output_file`` as passed in.

    Raises:
        ValueError: If a required key is missing from a batch's predictions,
            or if a sample's group cannot be created (e.g. duplicate name).
    """
    assert keys == "*" or isinstance(keys, (tuple, list)), f"Invalid keys: {keys!r}"
    optional_keys = [] if optional_keys is None else optional_keys

    if keys == "*":
        required_keys = export_keys = None
    else:
        # Sets built once: O(1) filtering per key, and a tuple of keys can be
        # combined with a list of optional keys (tuple + list raises TypeError).
        required_keys = set(keys)
        export_keys = required_keys | set(optional_keys)

    Path(output_file).parent.mkdir(exist_ok=True, parents=True)
    # The context manager closes the file even when an exception aborts the export.
    with h5py.File(str(output_file), "w") as hfile:
        device = get_device()
        model = model.to(device).eval()

        if not verbose:
            # No progress bar in quiet mode, so leave a single trace in the logs.
            logger.info(f"Exporting predictions to {output_file}")

        for data_ in tqdm(
            loader, desc="Exporting", total=len(loader), ncols=80, disable=not verbose
        ):
            data = batch_to_device(data_, device, non_blocking=True)
            pred = model(data)
            if callback_fn is not None:
                # ``pred`` is unpacked last so model outputs override callback outputs.
                pred = {**callback_fn(pred, data), **pred}

            if export_keys is not None:
                missing = required_keys.difference(pred)
                if missing:
                    raise ValueError(f"Missing key {missing}")
                pred = {k: v for k, v in pred.items() if k in export_keys}

            for idx, name in enumerate(data["name"]):
                # Slice sample ``idx`` out of every batched output and move it to host.
                sample = {k: v[idx].cpu().numpy() for k, v in pred.items()}
                if as_half:
                    sample = {
                        k: v.astype(np.float16) if v.dtype == np.float32 else v
                        for k, v in sample.items()
                    }

                try:
                    try:
                        grp = hfile.create_group(name)
                    except ValueError as e:
                        # A ValueError from create_group is reported as a duplicate name.
                        raise ValueError(f"Group already exists {name}") from e
                    for k, v in sample.items():
                        grp.create_dataset(k, data=v)
                except RuntimeError as err:
                    # Best effort: skip samples that fail to write, but record them.
                    logger.warning("Failed to export %s: %s", name, err)

            # Drop this batch's (possibly on-device) outputs before the next forward.
            del pred

    return output_file
|
|