Eachan Johnson commited on
Commit
5f87e71
·
1 Parent(s): 11292cd

Specify init params

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -20,6 +20,7 @@ DATA_PATH = f'{os.getcwd()}/stennett2022-table1.xlsx'
20
  CLASS_COL = 'class'
21
  YEAR_COL = 'year'
22
  FIG_PANEL_SIZE = 3.5
 
23
 
24
  NUMBER_DISCOVERED = "Discovered"
25
  NUMBER_WO_R = "Without resistance"
@@ -106,7 +107,7 @@ def dobj_fun(
106
 
107
  def fit_to_data(
108
  df: pd.DataFrame,
109
- init_params: ArrayLike
110
  ) -> Tuple[float]:
111
  init_params = np.asarray([float(p) for p in init_params])
112
  print_err(f"Fitting with init params = {init_params}")
@@ -117,6 +118,7 @@ def fit_to_data(
117
  function_to_minimize,
118
  x0=init_params,
119
  jac=jacobian,
 
120
  )
121
  print(do)
122
  return tuple(do.x.flatten())
@@ -368,8 +370,8 @@ with gr.Blocks() as demo:
368
  )
369
 
370
  (fit_button
371
- .click(lambda *x: fit_to_data(data, init_params=x), inputs=param_sliders, outputs=param_sliders)
372
- .then(lambda *x: plot_data_altair(df=data, params=x),inputs=param_sliders, outputs=plot)
373
  .then(lambda *x: plot_data_forecast_altair(df=data, params=x),inputs=param_and_forecast_sliders, outputs=forecast))
374
 
375
  (demo.load(lambda *x: plot_data_altair(df=data, params=x),inputs=param_sliders, outputs=plot).then(lambda *x: plot_data_forecast_altair(df=data, params=x),inputs=param_and_forecast_sliders, outputs=forecast))
 
20
  CLASS_COL = 'class'
21
  YEAR_COL = 'year'
22
  FIG_PANEL_SIZE = 3.5
23
+ INIT_PARAMS = (38.534, 1.743, 27.68, 12.1) # To speed up fitting on Spaces
24
 
25
  NUMBER_DISCOVERED = "Discovered"
26
  NUMBER_WO_R = "Without resistance"
 
107
 
108
  def fit_to_data(
109
  df: pd.DataFrame,
110
+ init_params: ArrayLike = INIT_PARAMS
111
  ) -> Tuple[float]:
112
  init_params = np.asarray([float(p) for p in init_params])
113
  print_err(f"Fitting with init params = {init_params}")
 
118
  function_to_minimize,
119
  x0=init_params,
120
  jac=jacobian,
121
+ method="bfgs",
122
  )
123
  print(do)
124
  return tuple(do.x.flatten())
 
370
  )
371
 
372
  (fit_button
373
+ .click(lambda *x: fit_to_data(data), inputs=None, outputs=param_sliders)
374
+ .then(lambda *x: plot_data_altair(df=data, params=x), inputs=param_sliders, outputs=plot)
375
  .then(lambda *x: plot_data_forecast_altair(df=data, params=x),inputs=param_and_forecast_sliders, outputs=forecast))
376
 
377
  (demo.load(lambda *x: plot_data_altair(df=data, params=x),inputs=param_sliders, outputs=plot).then(lambda *x: plot_data_forecast_altair(df=data, params=x),inputs=param_and_forecast_sliders, outputs=forecast))