Jayabalambika commited on
Commit
0afb817
·
1 Parent(s): 9a2dd60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -45
app.py CHANGED
@@ -35,57 +35,62 @@ def generate_plots(min_slider_samples_range,max_slider_samples_range):
35
  oa_mse[i, j] = oa.error_norm(real_cov, scaling=False)
36
  oa_shrinkage[i, j] = oa.shrinkage_
37
 
38
-
39
- # plot MSE
40
- plt.clf()
41
- plt.subplot(2, 1, 1)
42
- plt.errorbar(
43
- slider_samples_range,
44
- lw_mse.mean(1),
45
- yerr=lw_mse.std(1),
46
- label="Ledoit-Wolf",
47
- color="navy",
48
- lw=2,
49
- )
50
- plt.errorbar(
51
- slider_samples_range,
52
- oa_mse.mean(1),
53
- yerr=oa_mse.std(1),
54
- label="OAS",
55
- color="darkorange",
56
- lw=2,
57
- )
58
- plt.ylabel("Squared error")
59
- plt.legend(loc="upper right")
60
- plt.title("Comparison of covariance estimators")
61
- plt.xlim(5, 31)
62
-
63
- # plot shrinkage coefficient
64
- plt.subplot(2, 1, 2)
65
- plt.errorbar(
 
 
 
 
 
66
  slider_samples_range,
67
  lw_shrinkage.mean(1),
68
  yerr=lw_shrinkage.std(1),
69
  label="Ledoit-Wolf",
70
  color="navy",
71
  lw=2,
72
- )
73
- plt.errorbar(
74
  slider_samples_range,
75
  oa_shrinkage.mean(1),
76
  yerr=oa_shrinkage.std(1),
77
  label="OAS",
78
  color="darkorange",
79
  lw=2,
80
- )
81
- plt.xlabel("n_samples")
82
- plt.ylabel("Shrinkage")
83
- plt.legend(loc="lower right")
84
- plt.ylim(plt.ylim()[0], 1.0 + (plt.ylim()[1] - plt.ylim()[0]) / 10.0)
85
- plt.xlim(5, 31)
86
 
87
- # plt.show()
88
- return plt
89
 
90
 
91
 
@@ -118,8 +123,7 @@ with gr.Blocks(title=title, theme=gr.themes.Default(font=[gr.themes.GoogleFont("
118
 
119
 
120
  r = 0.1
121
- changed1 = False
122
- changed2 = False
123
  real_cov = toeplitz(r ** np.arange(n_features))
124
  coloring_matrix = cholesky(real_cov)
125
  gr.Markdown(" **[Demo is based on sklearn docs](https://scikit-learn.org/stable/auto_examples/covariance/plot_lw_vs_oas.html)**")
@@ -128,12 +132,12 @@ with gr.Blocks(title=title, theme=gr.themes.Default(font=[gr.themes.GoogleFont("
128
  # output = gr.Textbox(label="Output Box")
129
  # greet_btn.click(fn=greet, inputs=name, outputs=output)
130
  gr.Label(value="Comparison of Covariance Estimators")
 
131
  #if min_slider_samples_range:
132
- while not (changed1 and changed2):
133
- min_slider_samples_range.change(generate_plots, inputs=[min_slider_samples_range,max_slider_samples_range], outputs= gr.Plot() )
134
- max_slider_samples_range.change(generate_plots, inputs=[min_slider_samples_range,max_slider_samples_range], outputs= gr.Plot() )
135
- changed1 = True
136
- changed2 = True
137
 
138
 
139
  #elif max_slider_samples_range:
 
35
  oa_mse[i, j] = oa.error_norm(real_cov, scaling=False)
36
  oa_shrinkage[i, j] = oa.shrinkage_
37
 
38
+
39
+
40
+ def plot_mse():
41
+ # plot MSE
42
+ plt.clf()
43
+ plt.subplot(2, 1, 1)
44
+ plt.errorbar(
45
+ slider_samples_range,
46
+ lw_mse.mean(1),
47
+ yerr=lw_mse.std(1),
48
+ label="Ledoit-Wolf",
49
+ color="navy",
50
+ lw=2,
51
+ )
52
+ plt.errorbar(
53
+ slider_samples_range,
54
+ oa_mse.mean(1),
55
+ yerr=oa_mse.std(1),
56
+ label="OAS",
57
+ color="darkorange",
58
+ lw=2,
59
+ )
60
+ plt.ylabel("Squared error")
61
+ plt.legend(loc="upper right")
62
+ plt.title("Comparison of covariance estimators")
63
+ plt.xlim(5, 31)
64
+ return plt
65
+
66
+
67
+ def plot_shrinkage():
68
+ # plot shrinkage coefficient
69
+ plt.subplot(2, 1, 2)
70
+ plt.errorbar(
71
  slider_samples_range,
72
  lw_shrinkage.mean(1),
73
  yerr=lw_shrinkage.std(1),
74
  label="Ledoit-Wolf",
75
  color="navy",
76
  lw=2,
77
+ )
78
+ plt.errorbar(
79
  slider_samples_range,
80
  oa_shrinkage.mean(1),
81
  yerr=oa_shrinkage.std(1),
82
  label="OAS",
83
  color="darkorange",
84
  lw=2,
85
+ )
86
+ plt.xlabel("n_samples")
87
+ plt.ylabel("Shrinkage")
88
+ plt.legend(loc="lower right")
89
+ plt.ylim(plt.ylim()[0], 1.0 + (plt.ylim()[1] - plt.ylim()[0]) / 10.0)
90
+ plt.xlim(5, 31)
91
 
92
+ # plt.show()
93
+ return plt
94
 
95
 
96
 
 
123
 
124
 
125
  r = 0.1
126
+
 
127
  real_cov = toeplitz(r ** np.arange(n_features))
128
  coloring_matrix = cholesky(real_cov)
129
  gr.Markdown(" **[Demo is based on sklearn docs](https://scikit-learn.org/stable/auto_examples/covariance/plot_lw_vs_oas.html)**")
 
132
  # output = gr.Textbox(label="Output Box")
133
  # greet_btn.click(fn=greet, inputs=name, outputs=output)
134
  gr.Label(value="Comparison of Covariance Estimators")
135
+ generate_plots()
136
  #if min_slider_samples_range:
137
+
138
+ min_slider_samples_range.change(plot_mse, outputs= gr.Plot() )
139
+ max_slider_samples_range.change(plot_shrinkage, outputs= gr.Plot() )
140
+
 
141
 
142
 
143
  #elif max_slider_samples_range: