Spaces:
Running
Running
from rstor.learning.metrics import compute_metrics, ALL_METRICS | |
import torch | |
import numpy as np | |
from rstor.properties import METRIC_PSNR, METRIC_SSIM | |
from interactive_pipe import interactive, KeyboardControl | |
from typing import Optional | |
def plug_configure_metrics(key_shortcut: Optional[str] = None) -> None:
    """Wrap ``configure_metrics`` with ``interactive`` to expose ``advanced_metrics``.

    Args:
        key_shortcut: Key passed as ``keydown`` to a ``KeyboardControl`` whose
            initial value is ``False``.  When ``None``, no keyboard control is
            built and the tuple ``(True,)`` is handed to ``interactive``
            instead (presumably read by interactive_pipe as a boolean control
            defaulting to ``True`` -- confirm against the library).

    The object returned by the decorator call is discarded; only the side
    effects of decorating ``configure_metrics`` are relied upon.
    """
    if key_shortcut is None:
        advanced_control = (True,)
    else:
        advanced_control = KeyboardControl(False, keydown=key_shortcut)
    decorate = interactive(advanced_metrics=advanced_control)
    decorate(configure_metrics)
def configure_metrics(advanced_metrics=False, global_params={}) -> None:
    """Record in ``global_params`` which metrics should be computed.

    Args:
        advanced_metrics: When truthy, select ``ALL_METRICS``; otherwise
            select only ``[METRIC_PSNR, METRIC_SSIM]``.
        global_params: Dictionary that receives the selection under the
            ``"chosen_metrics"`` key.  It is mutated in place.
    """
    if advanced_metrics:
        selection = ALL_METRICS
    else:
        selection = [METRIC_PSNR, METRIC_SSIM]
    global_params["chosen_metrics"] = selection
def _as_batched_tensor(image):
    """Return ``image`` unchanged unless it is a NumPy array.

    A NumPy array is wrapped with ``torch.from_numpy``, its last axis is moved
    to the front with ``permute(-1, 0, 1)``, it is cast with ``.float()`` and
    a leading axis of size 1 is added with ``unsqueeze(0)``.
    """
    if isinstance(image, np.ndarray):
        # permute(-1, 0, 1) needs exactly three axes; presumably the array is
        # (H, W, C) and becomes (1, C, H, W) -- TODO confirm with callers.
        return torch.from_numpy(image).permute(-1, 0, 1).float().unsqueeze(0)
    return image


def get_metrics(prediction: torch.Tensor, target: torch.Tensor,
                image_name: str,  # bind with functools.partial to choose where the title appears
                global_params: dict = {}) -> None:
    """Compute metrics between ``prediction`` and ``target`` and publish them as a title.

    Args:
        prediction: Predicted image.  Despite the annotation, a NumPy array is
            also accepted and converted by ``_as_batched_tensor``.
        target: Reference image; same conversion rule as ``prediction``.
        image_name: Key under which the title is stored in
            ``global_params["__output_styles"]``; also the title prefix.
        global_params: Shared dictionary, mutated in place.  Reads
            ``"chosen_metrics"`` (falls back to ``[METRIC_PSNR]``), writes
            ``"metrics"`` and ``"__output_styles"[image_name]``.

    Returns:
        None.  Results are exposed only through ``global_params``.
    """
    # NOTE(review): the mutable default is kept so the signature is unchanged;
    # presumably the pipeline supplies its own dict here -- confirm.
    prediction_ = _as_batched_tensor(prediction)
    target_ = _as_batched_tensor(target)

    chosen_metrics = global_params.get("chosen_metrics", [METRIC_PSNR])
    metrics = compute_metrics(prediction_, target_, chosen_metrics=chosen_metrics)
    global_params["metrics"] = metrics

    summary = " ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
    title = f"{image_name}: {summary}"

    # setdefault avoids a KeyError when "__output_styles" has not been
    # created yet (e.g. when the function is called with a plain dict).
    output_styles = global_params.setdefault("__output_styles", {})
    output_styles[image_name] = {"title": title, "image_name": image_name}