Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
import torch | |
import numpy as np | |
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score | |
from functools import cache | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import io | |
from PIL import Image | |
class LabelWeightedBCELoss(nn.Module): | |
""" | |
Binary Cross Entropy loss that assumes each float in the final dimension is a binary probability distribution. | |
Allows for the weighing of each probability distribution wrt loss. | |
""" | |
def __init__(self, label_weights: torch.Tensor, reduction="mean"): | |
super().__init__() | |
self.label_weights = label_weights | |
match reduction: | |
case "mean": | |
self.reduction = torch.mean | |
case "sum": | |
self.reduction = torch.sum | |
def _log(self, x: torch.Tensor) -> torch.Tensor: | |
return torch.clamp_min(torch.log(x), -100) | |
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
losses = -self.label_weights * ( | |
target * self._log(input) + (1 - target) * self._log(1 - input) | |
) | |
return self.reduction(losses) | |
# TODO: Code a onehot | |
def calculate_metrics( | |
pred, target, threshold=0.5, prefix="", multi_label=True | |
) -> dict[str, torch.Tensor]: | |
target = target.detach().cpu().numpy() | |
pred = pred.detach().cpu() | |
if not multi_label: | |
pred = nn.functional.softmax(pred, dim=1) | |
pred = pred.numpy() | |
params = { | |
"y_true": np.array(target > 0.0, dtype=float) | |
if multi_label | |
else target.argmax(1), | |
"y_pred": np.array(pred > threshold, dtype=float) | |
if multi_label | |
else pred.argmax(1), | |
"zero_division": 0, | |
"average": "macro", | |
} | |
metrics = { | |
"precision": precision_score(**params), | |
"recall": recall_score(**params), | |
"f1": f1_score(**params), | |
"accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]), | |
} | |
return { | |
prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items() | |
} | |
class EarlyStopping: | |
def __init__(self, patience=0): | |
self.patience = patience | |
self.last_measure = np.inf | |
self.consecutive_increase = 0 | |
def step(self, val) -> bool: | |
if self.last_measure <= val: | |
self.consecutive_increase += 1 | |
else: | |
self.consecutive_increase = 0 | |
self.last_measure = val | |
return self.patience < self.consecutive_increase | |
def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]: | |
id2label = {str(i): label for i, label in enumerate(labels)} | |
label2id = {label: str(i) for i, label in enumerate(labels)} | |
return id2label, label2id | |
def compute_hf_metrics(eval_pred): | |
predictions = np.argmax(eval_pred.predictions, axis=1) | |
return accuracy_score(y_true=eval_pred.label_ids, y_pred=predictions) | |
def get_dance_mapping(mapping_file: str) -> dict[str, str]: | |
mapping_df = pd.read_csv(mapping_file) | |
return {row["id"]: row["name"] for _, row in mapping_df.iterrows()} | |
def plot_to_image(figure) -> np.ndarray: | |
"""Converts the matplotlib plot specified by 'figure' to a PNG image and | |
returns it. The supplied figure is closed and inaccessible after this call.""" | |
# Save the plot to a PNG in memory. | |
buf = io.BytesIO() | |
plt.savefig(buf, format="png") | |
# Closing the figure prevents it from being displayed directly inside | |
# the notebook. | |
plt.close(figure) | |
buf.seek(0) | |
return np.array(Image.open(buf)) | |