MilesCranmer commited on
Commit
c52d151
1 Parent(s): c1a4fec

Standardize figure settings

Browse files
Files changed (1) hide show
  1. gui/plots.py +13 -22
gui/plots.py CHANGED
@@ -20,7 +20,6 @@ def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
20
  if len(df) == 0 or "Equation" not in df.columns:
21
  return fig
22
 
23
- # Plotting the data
24
  ax.loglog(
25
  df["Complexity"],
26
  df["Loss"],
@@ -31,23 +30,12 @@ def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
31
  markersize=6,
32
  )
33
 
34
- # Set the axis limits
35
  ax.set_xlim(0.5, maxsize + 1)
36
  ytop = 2 ** (np.ceil(np.log2(df["Loss"].max())))
37
  ybottom = 2 ** (np.floor(np.log2(df["Loss"].min() + 1e-20)))
38
  ax.set_ylim(ybottom, ytop)
39
 
40
- ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
41
- ax.spines["top"].set_visible(False)
42
- ax.spines["right"].set_visible(False)
43
-
44
- # Range-frame the plot
45
- for direction in ["bottom", "left"]:
46
- ax.spines[direction].set_position(("outward", 10))
47
-
48
- # Delete far ticks
49
- ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
50
- ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
51
 
52
  ax.set_xlabel("Complexity")
53
  ax.set_ylabel("Loss")
@@ -57,14 +45,23 @@ def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
57
 
58
 
59
  def plot_example_data(test_equation, num_points, noise_level, data_seed):
 
 
60
  X, y = generate_data(test_equation, num_points, noise_level, data_seed)
61
  x = X["x"]
62
 
63
- plt.rcParams["font.family"] = "IBM Plex Mono"
64
- fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
65
-
66
  ax.scatter(x, y, alpha=0.7, edgecolors="w", s=50)
67
 
 
 
 
 
 
 
 
 
 
 
68
  ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
69
  ax.spines["top"].set_visible(False)
70
  ax.spines["right"].set_visible(False)
@@ -76,9 +73,3 @@ def plot_example_data(test_equation, num_points, noise_level, data_seed):
76
  # Delete far ticks
77
  ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
78
  ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
79
-
80
- ax.set_xlabel("x")
81
- ax.set_ylabel("y")
82
- fig.tight_layout(pad=2)
83
-
84
- return fig
 
20
  if len(df) == 0 or "Equation" not in df.columns:
21
  return fig
22
 
 
23
  ax.loglog(
24
  df["Complexity"],
25
  df["Loss"],
 
30
  markersize=6,
31
  )
32
 
 
33
  ax.set_xlim(0.5, maxsize + 1)
34
  ytop = 2 ** (np.ceil(np.log2(df["Loss"].max())))
35
  ybottom = 2 ** (np.floor(np.log2(df["Loss"].min() + 1e-20)))
36
  ax.set_ylim(ybottom, ytop)
37
 
38
+ stylize_axis(ax)
 
 
 
 
 
 
 
 
 
 
39
 
40
  ax.set_xlabel("Complexity")
41
  ax.set_ylabel("Loss")
 
45
 
46
 
47
  def plot_example_data(test_equation, num_points, noise_level, data_seed):
48
+ fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
49
+
50
  X, y = generate_data(test_equation, num_points, noise_level, data_seed)
51
  x = X["x"]
52
 
 
 
 
53
  ax.scatter(x, y, alpha=0.7, edgecolors="w", s=50)
54
 
55
+ stylize_axis(ax)
56
+
57
+ ax.set_xlabel("x")
58
+ ax.set_ylabel("y")
59
+ fig.tight_layout(pad=2)
60
+
61
+ return fig
62
+
63
+
64
+ def stylize_axis(ax):
65
  ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
66
  ax.spines["top"].set_visible(False)
67
  ax.spines["right"].set_visible(False)
 
73
  # Delete far ticks
74
  ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
75
  ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)