chainyo commited on
Commit
8b9d8ef
·
1 Parent(s): 8dc61d0

fix distplots

Browse files
Files changed (1) hide show
  1. main.py +4 -2
main.py CHANGED
@@ -82,7 +82,9 @@ def get_plot(timers: Dict[str, Union[float, List[float]]]) -> plotly.graph_objs.
82
  """
83
  data = pd.DataFrame.from_dict(timers, orient="columns")
84
  colors = ["#84353f", "#b4524b", "#f47e58", "#ffbe67"]
85
- fig = ff.create_distplot([data[col] for col in data.columns], data.columns, bin_size=0.2, colors=colors)
 
 
86
  fig.update_layout(title_text="Inference Time", xaxis_title="Inference Time (s)", yaxis_title="Number of Samples")
87
  return fig
88
 
@@ -196,7 +198,7 @@ st.session_state["init_models"] = False
196
  if "inference_timers" not in st.session_state:
197
  st.session_state["inference_timers"] = {}
198
 
199
- exp_number = st.slider("The number of experiments per model.", min_value=100, max_value=300, value=150)
200
  get_only_mean = st.checkbox("Get only the mean of the inference time for each model.", value=False)
201
  input_text = st.text_area(
202
  "Enter text to classify",
 
82
  """
83
  data = pd.DataFrame.from_dict(timers, orient="columns")
84
  colors = ["#84353f", "#b4524b", "#f47e58", "#ffbe67"]
85
+ fig = ff.create_distplot(
86
+ [data[col] for col in data.columns], data.columns, bin_size=0.001, colors=colors, show_curve=False
87
+ )
88
  fig.update_layout(title_text="Inference Time", xaxis_title="Inference Time (s)", yaxis_title="Number of Samples")
89
  return fig
90
 
 
198
  if "inference_timers" not in st.session_state:
199
  st.session_state["inference_timers"] = {}
200
 
201
+ exp_number = st.slider("The number of experiments per model.", min_value=10, max_value=300, value=150)
202
  get_only_mean = st.checkbox("Get only the mean of the inference time for each model.", value=False)
203
  input_text = st.text_area(
204
  "Enter text to classify",