Commit 889be89 · committed by marik0 · Parent: f7996e9

Add sliders for dataset

Files changed (1): app.py (+24, -23)
app.py CHANGED

```diff
@@ -9,11 +9,11 @@ import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
 
-def create_dataset():
+def create_dataset(num_samples, num_informative):
     X, y = make_classification(
-        n_samples=1000,
+        n_samples=num_samples,
         n_features=10,
-        n_informative=3,
+        n_informative=num_informative,
         n_redundant=0,
         n_repeated=0,
         n_classes=2,
@@ -24,17 +24,6 @@ def create_dataset():
     X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)
     return X_train, X_test, y_train, y_test
 
-def train_model():
-
-    X_train, X_test, y_train, y_test = create_dataset()
-
-    feature_names = [f"feature {i}" for i in range(X_train.shape[1])]
-    forest = RandomForestClassifier(random_state=0)
-    forest.fit(X_train, y_train)
-
-    return forest, feature_names, X_test, y_test
-
-
 def plot_mean_decrease(clf, feature_names):
     importances = clf.feature_importances_
     std = np.std([tree.feature_importances_ for tree in clf.estimators_], axis=0)
@@ -63,11 +52,25 @@ def plot_feature_perm(clf, feature_names, X_test, y_test):
 
     return fig
 
+def train_model(num_samples, num_info):
+
+    X_train, X_test, y_train, y_test = create_dataset(num_samples, num_info)
+
+    feature_names = [f"feature {i}" for i in range(X_train.shape[1])]
+    forest = RandomForestClassifier(random_state=0)
+    forest.fit(X_train, y_train)
+
+    fig = plot_mean_decrease(forest, feature_names)
+    fig2 = plot_feature_perm(forest, feature_names, X_test, y_test)
+    return fig, fig2
+
 
 
 title = "Feature importances with a forest of trees 🌳"
 description = """This example shows the use of a forest of trees to evaluate the importance of features on an artificial classification task.
 The blue bars are the feature importances of the forest, along with their inter-trees variability represented by the error bars.
+
+The model is trained with simulated data.
 """
 
 with gr.Blocks() as demo:
@@ -75,18 +78,16 @@ with gr.Blocks() as demo:
     gr.Markdown(description)
 
     # with gr.Column():
-    clf, feature_names, X_test, y_test = train_model()
 
+    num_samples = gr.Slider(minimum=1000, maximum=5000, step=500, value=1000, label="Number of samples")
+    num_info = gr.Slider(minimum=2, maximum=10, step=1, value=3, label="Number of informative features")
+
     with gr.Row():
-        plot = gr.Plot(plot_mean_decrease(clf, feature_names))
-        plot2 = gr.Plot(plot_feature_perm(clf, feature_names, X_test, y_test))
+        plot = gr.Plot()
+        plot2 = gr.Plot()
 
-    # input_data = gr.Dropdown(choices=feature_names, label="Feature", value="body-mass index")
-    # coef = gr.Textbox(label="Coefficients")
-    # mse = gr.Textbox(label="Mean squared error (MSE)")
-    # r2 = gr.Textbox(label="R2 score")
-
-    # input_data.change(fn=train_model, inputs=[input_data], outputs=[plot, coef, mse, r2], queue=False)
+    num_samples.change(fn=train_model, inputs=[num_samples, num_info], outputs=[plot, plot2])
+    num_info.change(fn=train_model, inputs=[num_samples, num_info], outputs=[plot, plot2])
 
 
 demo.launch(enable_queue=True)
```
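For reference, the interaction pattern this commit introduces (sliders whose `.change` events retrain the model and push fresh figures into `gr.Plot` components) can be reduced to a minimal self-contained sketch. This is an illustration only: the function name `retrain` and the single-plot layout are assumptions, not the Space's actual code.

```python
# Minimal sketch of the slider -> retrain -> plot pattern (illustrative names,
# not the Space's actual code).
import gradio as gr
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

def retrain(num_samples, num_informative):
    # Rebuild the synthetic dataset and refit the forest from the slider values.
    X, y = make_classification(
        n_samples=int(num_samples),
        n_features=10,
        n_informative=int(num_informative),
        n_redundant=0,
        random_state=42,
    )
    forest = RandomForestClassifier(random_state=0).fit(X, y)
    # Plot impurity-based importances, as plot_mean_decrease does in the app.
    fig, ax = plt.subplots()
    ax.bar(range(X.shape[1]), forest.feature_importances_)
    ax.set_xlabel("feature")
    ax.set_ylabel("mean decrease in impurity")
    return fig

with gr.Blocks() as demo:
    n = gr.Slider(minimum=1000, maximum=5000, step=500, value=1000, label="Number of samples")
    k = gr.Slider(minimum=2, maximum=10, step=1, value=3, label="Number of informative features")
    out = gr.Plot()
    # Each slider's change event re-runs training and pushes a fresh figure.
    n.change(fn=retrain, inputs=[n, k], outputs=out)
    k.change(fn=retrain, inputs=[n, k], outputs=out)

demo.launch()
```

As in the commit itself, the plots start out empty and are only rendered after a slider is first moved, since the `gr.Plot` components are created without an initial figure.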
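One version-dependent caveat: the `enable_queue` argument to `launch()` was deprecated in later Gradio releases in favor of calling `queue()` on the Blocks object, so an upgraded Space would end with roughly:

```python
demo.queue()   # replaces launch(enable_queue=True) on newer Gradio versions
demo.launch()
```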