Spaces:
Running
Running
File size: 1,625 Bytes
cec5823 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
from rstor.learning.metrics import compute_metrics, ALL_METRICS
import torch
import numpy as np
from rstor.properties import METRIC_PSNR, METRIC_SSIM
from interactive_pipe import interactive, KeyboardControl
from typing import Optional
def plug_configure_metrics(key_shortcut: Optional[str] = None) -> None:
    """Wrap ``configure_metrics`` with the ``interactive`` decorator.

    The ``advanced_metrics`` argument of ``configure_metrics`` is declared
    either as a ``KeyboardControl`` or as the tuple ``(True,)``, depending on
    whether a keyboard shortcut is supplied.

    Args:
        key_shortcut: Key forwarded as ``keydown`` to ``KeyboardControl``
            (whose initial value is ``False``). When ``None``, the tuple
            ``(True,)`` is used as the declaration instead.
    """
    if key_shortcut is None:
        # NOTE(review): presumably interactive_pipe reads this tuple as a
        # boolean control defaulting to True -- confirm against its docs.
        advanced_metrics_control = (True,)
    else:
        advanced_metrics_control = KeyboardControl(False, keydown=key_shortcut)
    decorate = interactive(advanced_metrics=advanced_metrics_control)
    # The decorated function is not kept; only the call itself matters here.
    decorate(configure_metrics)
def configure_metrics(advanced_metrics=False, global_params={}) -> None:
    """Choose the set of metrics to compute and record it.

    The selection is written to ``global_params["chosen_metrics"]``.

    Args:
        advanced_metrics: When truthy, ``ALL_METRICS`` is selected; otherwise
            only ``[METRIC_PSNR, METRIC_SSIM]``.
        global_params: Dictionary updated in place with the selection.
    """
    # NOTE(review): the mutable ``{}`` default is shared by every call that
    # omits ``global_params``; presumably interactive_pipe always injects its
    # own dictionary here -- confirm before changing the default.
    if advanced_metrics:
        global_params["chosen_metrics"] = ALL_METRICS
    else:
        global_params["chosen_metrics"] = [METRIC_PSNR, METRIC_SSIM]
def _as_batched_tensor(image):
    """Return ``image`` unchanged unless it is a NumPy array, which is converted to a tensor."""
    if isinstance(image, np.ndarray):
        # Move the last axis to the front, cast to float and prepend a size-1
        # axis. ``permute(-1, 0, 1)`` requires a 3-dimensional array.
        return torch.from_numpy(image).permute(-1, 0, 1).float().unsqueeze(0)
    return image


def get_metrics(prediction: torch.Tensor, target: torch.Tensor,
                image_name: str,  # use functools.partial to root where you want the title to appear
                global_params: dict = {}) -> None:
    """Compute metrics between ``prediction`` and ``target`` and publish them.

    The metrics are stored under ``global_params["metrics"]`` and a title
    summarising them is stored under
    ``global_params["__output_styles"][image_name]``.

    Args:
        prediction: Predicted image. NumPy arrays are also accepted and are
            converted with ``permute(-1, 0, 1).float().unsqueeze(0)``; any
            other value is passed to ``compute_metrics`` as-is.
        target: Reference image, handled the same way as ``prediction``.
        image_name: Key of the output style entry and prefix of the title.
        global_params: Dictionary read for ``"chosen_metrics"`` (defaulting
            to ``[METRIC_PSNR]``) and updated in place with the results.
    """
    # NOTE(review): the mutable ``{}`` default is shared by every call that
    # omits ``global_params``; presumably interactive_pipe always injects its
    # own dictionary here -- confirm before changing the default.
    prediction_ = _as_batched_tensor(prediction)
    target_ = _as_batched_tensor(target)
    chosen_metrics = global_params.get("chosen_metrics", [METRIC_PSNR])
    metrics = compute_metrics(prediction_, target_, chosen_metrics=chosen_metrics)
    global_params["metrics"] = metrics
    summary = " ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
    title = f"{image_name}: {summary}"
    # Create the styles mapping on demand so a dictionary without
    # "__output_styles" (including the default one) no longer raises KeyError
    # after the metrics have already been computed and stored.
    output_styles = global_params.setdefault("__output_styles", {})
    output_styles[image_name] = {"title": title, "image_name": image_name}
|