|
import json |
|
import math |
|
import os |
|
from typing import List, Optional |
|
|
|
from transformers.trainer import TRAINER_STATE_NAME |
|
|
|
from .logging import get_logger |
|
from .packages import is_matplotlib_available |
|
|
|
|
|
if is_matplotlib_available(): |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
def smooth(scalars: List[float]) -> List[float]:
    r"""
    EMA implementation according to TensorBoard.

    The smoothing weight is computed from the number of points through a
    sigmoid, so longer series are smoothed more strongly; the weight grows
    towards 0.9 as the series gets longer.

    Args:
        scalars: raw metric values in logging order.

    Returns:
        A new list with one smoothed value per input value, or an empty list
        when `scalars` is empty. The input list is not modified.
    """
    # Guard against empty input: indexing scalars[0] below would raise IndexError.
    if len(scalars) == 0:
        return []

    # Sigmoid of the series length, shifted and rescaled so that it starts near 0.
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)

    smoothed = []
    last = scalars[0]  # seed the running average with the first observation
    for next_val in scalars:
        last = last * weight + (1 - weight) * next_val
        smoothed.append(last)
    return smoothed
|
|
|
|
|
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = None) -> None:
    r"""
    Plots logged metrics from a trainer state file and saves one PNG per metric.

    Loads the JSON file named `TRAINER_STATE_NAME` inside `save_dictionary`,
    gathers every `log_history` entry that contains the requested metric, and
    draws the raw curve together with its smoothed version. Each figure is
    written to `save_dictionary` as `training_<key>.png`.

    Args:
        save_dictionary: directory that holds the trainer state file; the
            figures are saved into the same directory.
        keys: names of the metrics to plot. Defaults to `["loss"]` when
            omitted or `None`.
    """
    # Use a None sentinel instead of a shared mutable list as the default value.
    if keys is None:
        keys = ["loss"]

    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
        data = json.load(f)

    log_history = data["log_history"]
    for key in keys:
        # Only log entries that actually recorded this metric contribute a point.
        entries = [entry for entry in log_history if key in entry]
        if len(entries) == 0:
            logger.warning(f"No metric {key} to plot.")
            continue

        steps = [entry["step"] for entry in entries]
        metrics = [entry[key] for entry in entries]
        figure_path = os.path.join(save_dictionary, "training_{}.png".format(key))

        figure = plt.figure()
        try:
            plt.plot(steps, metrics, alpha=0.4, label="original")
            plt.plot(steps, smooth(metrics), label="smoothed")
            plt.title("training {} of {}".format(key, save_dictionary))
            plt.xlabel("step")
            plt.ylabel(key)
            plt.legend()
            plt.savefig(figure_path, format="png", dpi=100)
        finally:
            # Release the figure so repeated calls do not pile up open figures.
            plt.close(figure)

        print("Figure saved:", figure_path)
|
|