File size: 1,625 Bytes
cec5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from rstor.learning.metrics import compute_metrics, ALL_METRICS
import torch
import numpy as np
from rstor.properties import METRIC_PSNR, METRIC_SSIM
from interactive_pipe import interactive, KeyboardControl
from typing import Optional


def plug_configure_metrics(key_shortcut: Optional[str] = None) -> None:
    """Attach an interactive ``advanced_metrics`` control to ``configure_metrics``.

    Args:
        key_shortcut: When given, the control is a ``KeyboardControl`` bound to this
            key and starting at ``False``. When ``None``, the plain tuple ``(True,)``
            is used as the control spec instead.

    The decorated function returned by ``interactive(...)`` is not kept here; the
    call is made for its registration side effect inside interactive_pipe
    (presumably — confirm against the interactive_pipe version in use).
    """
    # NOTE(review): the starting value differs between the two branches
    # (False with a shortcut, True without) — confirm this asymmetry is intended.
    if key_shortcut is None:
        control = (True,)
    else:
        control = KeyboardControl(False, keydown=key_shortcut)
    register = interactive(advanced_metrics=control)
    register(configure_metrics)


def configure_metrics(advanced_metrics=False, global_params={}) -> None:
    """Select which metrics downstream filters should compute.

    Args:
        advanced_metrics: If truthy, select every metric in ``ALL_METRICS``;
            otherwise select only PSNR and SSIM.
        global_params: Shared state dict; the selection is written to its
            ``"chosen_metrics"`` key (read back by ``get_metrics``).
    """
    if not advanced_metrics:
        global_params["chosen_metrics"] = [METRIC_PSNR, METRIC_SSIM]
        return
    # NOTE: stores a reference to the module-level ALL_METRICS list, not a copy.
    global_params["chosen_metrics"] = ALL_METRICS


def _to_batched_tensor(image):
    """Convert a numpy image to a float ``(1, C, H, W)`` tensor; pass anything else through.

    A 3-D array is treated as channels-last ``(H, W, C)``: its last axis is moved to
    the front. A 2-D array is treated as a single-channel ``(H, W)`` image. Values are
    cast to float without rescaling. Non-numpy inputs are returned unchanged.
    """
    if not isinstance(image, np.ndarray):
        return image
    # torch.from_numpy rejects arrays with negative strides (e.g. flipped views);
    # ascontiguousarray returns the array itself when it is already C-contiguous.
    tensor = torch.from_numpy(np.ascontiguousarray(image))
    if tensor.ndim == 2:
        # Grayscale (H, W): add a trailing channel axis so it follows the HWC path.
        tensor = tensor.unsqueeze(-1)
    return tensor.permute(-1, 0, 1).float().unsqueeze(0)


def get_metrics(prediction: torch.Tensor, target: torch.Tensor,
                image_name: str,  # bind with functools.partial to choose which output gets the title
                global_params: dict = {}) -> None:
    """Compute metrics between ``prediction`` and ``target`` and publish them.

    Args:
        prediction: Predicted image. numpy arrays are converted to a batched float
            tensor (channels-last ``(H, W, C)`` or grayscale ``(H, W)``); any other
            value is handed to ``compute_metrics`` unchanged (presumably already a
            batched ``(N, C, H, W)`` tensor — TODO confirm).
        target: Reference image, same conventions as ``prediction``.
        image_name: Key under ``global_params["__output_styles"]`` whose style
            receives the metrics title.
        global_params: Shared state dict. Reads ``"chosen_metrics"`` (falls back to
            PSNR only), writes ``"metrics"`` and ``"__output_styles"[image_name]``.

    NOTE(review): the mutable ``{}`` default is kept to preserve the filter
    signature; presumably the pipeline always supplies ``global_params`` — confirm
    before switching it to ``None``.
    """
    prediction_ = _to_batched_tensor(prediction)
    target_ = _to_batched_tensor(target)
    chosen_metrics = global_params.get("chosen_metrics", [METRIC_PSNR])
    metrics = compute_metrics(prediction_, target_, chosen_metrics=chosen_metrics)
    global_params["metrics"] = metrics
    title = f"{image_name}: " + " ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
    # setdefault avoids a KeyError when no output-style registry has been created yet.
    output_styles = global_params.setdefault("__output_styles", {})
    output_styles[image_name] = {"title": title, "image_name": image_name}