Spaces:
Running
Running
File size: 2,178 Bytes
b955f86 9fa2182 985f8fa 9fa2182 985f8fa b955f86 86e9755 967d63f 86e9755 9fa2182 c1a4fec 9fa2182 c52d151 9fa2182 c1a4fec c52d151 9fa2182 c52d151 84b46ac c52d151 9fa2182 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import logging
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
# Disable pyplot's interactive mode: the functions in this module build and
# return figures rather than displaying them.
plt.ioff()

# Preferred font families, in order; the entries after the first are
# fallbacks.
plt.rcParams.update(
    {
        "font.family": [
            "IBM Plex Mono",
            "DejaVu Sans Mono",
            "Courier New",
            "monospace",
        ]
    }
)

# NOTE(review): presumably silences font_manager's "font not found" messages
# when a preferred family above is not installed — confirm.
logging.getLogger("matplotlib.font_manager").disabled = True
from data import generate_data
def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
    """Plot the loss-versus-complexity curve of a set of candidate equations.

    ``df["Loss"]`` is drawn against ``df["Complexity"]`` on log-log axes as a
    line with circular markers. The x-axis spans ``[0.5, maxsize + 1]`` and
    the y-axis is snapped outward to the enclosing powers of two.

    Args:
        df: Table with ``"Equation"``, ``"Complexity"`` and ``"Loss"``
            columns.
        maxsize: Largest complexity to show on the x-axis.

    Returns:
        The matplotlib ``Figure``. If ``df`` is empty or lacks any of the
        three columns above, the figure is returned with an empty axis.
    """
    fig, ax = plt.subplots(figsize=(6, 6), dpi=100)

    # Check every column read below (not only "Equation") so that a partial
    # table yields an empty figure instead of a KeyError.
    required_columns = {"Equation", "Complexity", "Loss"}
    if len(df) == 0 or not required_columns.issubset(df.columns):
        return fig

    ax.loglog(
        df["Complexity"],
        df["Loss"],
        marker="o",
        linestyle="-",
        color="#333f48",
        linewidth=1.5,
        markersize=6,
    )
    ax.set_xlim(0.5, maxsize + 1)

    # Snap the y-limits to powers of two. Only finite, non-negative losses
    # are considered: a NaN/inf loss would turn the limits into NaN/inf,
    # which set_ylim rejects. The epsilon keeps log2 finite for a zero loss
    # and is applied to both ends, so an all-zero column still gives a
    # positive top limit above the bottom one.
    eps = 1e-20
    loss = df["Loss"].to_numpy(dtype=float)
    loss = loss[np.isfinite(loss) & (loss >= 0)]
    if loss.size > 0:
        ytop = 2.0 ** np.ceil(np.log2(loss.max() + eps))
        ybottom = 2.0 ** np.floor(np.log2(loss.min() + eps))
        if ytop <= ybottom:
            # Every loss is the same exact power of two; widen by one factor
            # of two instead of passing identical limits.
            ytop = 2.0 * ybottom
        ax.set_ylim(ybottom, ytop)

    stylize_axis(ax)
    ax.set_xlabel("Complexity")
    ax.set_ylabel("Loss")
    fig.tight_layout(pad=2)
    return fig
def plot_example_data(test_equation, num_points, noise_level, data_seed):
    """Scatter-plot a sample dataset produced by ``generate_data``.

    All four arguments are forwarded unchanged to ``generate_data``. Column
    ``"x"`` of its first return value goes on the horizontal axis; its second
    return value goes on the vertical axis.

    Returns:
        A 6x6-inch, 100-dpi matplotlib ``Figure``.
    """
    figure, axis = plt.subplots(figsize=(6, 6), dpi=100)
    features, target = generate_data(
        test_equation, num_points, noise_level, data_seed
    )

    axis.scatter(features["x"], target, alpha=0.7, edgecolors="w", s=50)
    stylize_axis(axis)
    axis.set_xlabel("x")
    axis.set_ylabel("y")

    figure.tight_layout(pad=2)
    return figure
def plot_predictions(y, ypred):
    """Scatter-plot predicted values against the true values.

    Args:
        y: True values, placed on the horizontal axis.
        ypred: Predicted values, placed on the vertical axis.

    Returns:
        A 6x6-inch, 100-dpi matplotlib ``Figure``.
    """
    figure, axis = plt.subplots(figsize=(6, 6), dpi=100)

    axis.scatter(y, ypred, alpha=0.7, edgecolors="w", s=50)
    stylize_axis(axis)
    axis.set_xlabel("true")
    axis.set_ylabel("prediction")

    figure.tight_layout(pad=2)
    return figure
def stylize_axis(ax):
    """Apply the module's shared look to a matplotlib axis, in place.

    Adds a light dashed grid on major and minor ticks, hides the top and
    right spines, and offsets the bottom and left spines outward. It also
    makes all ticks point outward: major ticks have length 5 and label size
    10, minor ticks length 3 and label size 8.
    """
    ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)

    # Open frame: keep only the bottom and left spines.
    for hidden_side in ("top", "right"):
        ax.spines[hidden_side].set_visible(False)

    # Push the remaining spines 10 points away from the data area.
    for shifted_side in ("bottom", "left"):
        ax.spines[shifted_side].set_position(("outward", 10))

    # Outward ticks; minor ticks are shorter with smaller labels.
    for tick_kind, label_size, tick_length in (("major", 10, 5), ("minor", 8, 3)):
        ax.tick_params(
            axis="both",
            which=tick_kind,
            labelsize=label_size,
            direction="out",
            length=tick_length,
        )
|