File size: 2,178 Bytes
b955f86
 
9fa2182
 
985f8fa
9fa2182
985f8fa
 
 
 
 
 
 
 
b955f86
86e9755
967d63f
86e9755
9fa2182
c1a4fec
9fa2182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c52d151
9fa2182
 
 
 
 
 
 
 
c1a4fec
c52d151
 
9fa2182
 
 
 
 
c52d151
 
 
 
 
 
 
 
 
84b46ac
 
 
 
 
 
 
 
 
 
 
 
 
 
c52d151
9fa2182
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import logging

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# Build figures without matplotlib's interactive mode: the functions below
# create Figure objects and return them to the caller rather than showing them.
plt.ioff()
plt.rcParams.update(
    {
        "font.family": [
            "IBM Plex Mono",
            # Fallback fonts:
            "DejaVu Sans Mono",
            "Courier New",
            "monospace",
        ]
    }
)
# NOTE(review): presumably silences font-fallback / missing-font log noise
# when IBM Plex Mono is not installed — confirm.
logging.getLogger("matplotlib.font_manager").disabled = True

from data import generate_data


def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
    fig, ax = plt.subplots(figsize=(6, 6), dpi=100)

    if len(df) == 0 or "Equation" not in df.columns:
        return fig

    ax.loglog(
        df["Complexity"],
        df["Loss"],
        marker="o",
        linestyle="-",
        color="#333f48",
        linewidth=1.5,
        markersize=6,
    )

    ax.set_xlim(0.5, maxsize + 1)
    ytop = 2 ** (np.ceil(np.log2(df["Loss"].max())))
    ybottom = 2 ** (np.floor(np.log2(df["Loss"].min() + 1e-20)))
    ax.set_ylim(ybottom, ytop)

    stylize_axis(ax)

    ax.set_xlabel("Complexity")
    ax.set_ylabel("Loss")
    fig.tight_layout(pad=2)

    return fig


def plot_example_data(test_equation, num_points, noise_level, data_seed):
    """Scatter-plot a freshly generated example dataset.

    The data come from ``generate_data`` called with the given equation,
    number of points, noise level and seed. The ``"x"`` column of the
    returned features is plotted against the returned targets.

    Returns:
        A new matplotlib Figure.
    """
    marker_style = dict(alpha=0.7, edgecolors="w", s=50)

    fig, ax = plt.subplots(figsize=(6, 6), dpi=100)

    features, target = generate_data(test_equation, num_points, noise_level, data_seed)
    ax.scatter(features["x"], target, **marker_style)

    stylize_axis(ax)
    ax.set(xlabel="x", ylabel="y")
    fig.tight_layout(pad=2)
    return fig


def plot_predictions(y, ypred):
    """Scatter predicted values against the true targets.

    Args:
        y: True target values, drawn on the x-axis.
        ypred: Predicted values, drawn on the y-axis and paired element-wise
            with ``y``.

    Returns:
        A new matplotlib Figure.
    """
    marker_style = dict(alpha=0.7, edgecolors="w", s=50)

    fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
    ax.scatter(y, ypred, **marker_style)

    stylize_axis(ax)
    ax.set(xlabel="true", ylabel="prediction")
    fig.tight_layout(pad=2)
    return fig


def stylize_axis(ax):
    """Apply the shared styling used by the plotting functions, in place.

    Adds a dashed grid at both major and minor ticks and hides the top and
    right spines. It also pushes the bottom and left spines outward and sets
    the tick sizes and directions.

    Args:
        ax: The matplotlib Axes to modify. Nothing is returned.
    """
    ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
    # Keep only the bottom/left spines for an open, L-shaped frame.
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Offset the remaining spines 10 points away from the data area. This is
    # only an offset: the spines still span the whole axis (a true range frame
    # would additionally bound them to the data extent).
    for direction in ["bottom", "left"]:
        ax.spines[direction].set_position(("outward", 10))

    # Outward-pointing ticks; major ticks are longer and carry larger labels
    # than minor ticks.
    ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
    ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)