Michelle Lam committed on
Commit c55019b · 1 Parent(s): 4adf2d3

Refreshes overall plot when model is edited. Uses user ID in preds_df. Exports reports and chart to json. Streamlines performance plot and description.

audit_utils.py CHANGED
@@ -139,7 +139,7 @@ def setup_user_model_dirs(cur_user, cur_model):
 # Charts
 def get_chart_file(cur_user, cur_model):
     chart_dir = f"./data/output/{cur_user}/{cur_model}"
-    return os.path.join(chart_dir, f"chart__overall_vis.pkl")
+    return os.path.join(chart_dir, f"chart_overall_vis.json")
 
 # Labels
 def get_label_dir(cur_user, cur_model):
@@ -174,7 +174,7 @@ def get_preds_file(cur_user, cur_model):
 
 # Reports
 def get_reports_file(cur_user, cur_model):
-    return f"./data/output/{cur_user}/{cur_model}/reports.pkl"
+    return f"./data/output/{cur_user}/{cur_model}/reports.json"
 
 ########################################
 # General utils
@@ -236,14 +236,14 @@ def plot_metric_histogram(metric, user_metric, other_metric_vals, n_bins=10):
     return (bar + rule).interactive()
 
 # Generates the summary plot across all topics for the user
-def show_overall_perf(cur_model, error_type, cur_user, threshold=TOXIC_THRESHOLD, topic_vis_method="median"):
+def show_overall_perf(cur_model, error_type, cur_user, threshold=TOXIC_THRESHOLD, topic_vis_method="median", use_cache=True):
     # Your perf (calculate using model and testset)
     preds_file = get_preds_file(cur_user, cur_model)
     with open(preds_file, "rb") as f:
         preds_df = pickle.load(f)
 
     chart_file = get_chart_file(cur_user, cur_model)
-    if os.path.isfile(chart_file):
+    if use_cache and os.path.isfile(chart_file):
         # Read from file if it exists
         with open(chart_file, "r") as f:
             topic_overview_plot_json = json.load(f)
@@ -254,6 +254,9 @@ def show_overall_perf(cur_model, error_type, cur_user, threshold=TOXIC_THRESHOLD
         elif topic_vis_method == "mean":
             preds_df_grp = preds_df.groupby(["topic", "user_id"]).mean()
             topic_overview_plot_json = plot_overall_vis(preds_df=preds_df_grp, n_topics=200, threshold=threshold, error_type=error_type, cur_user=cur_user, cur_model=cur_model)
+        # Save to file
+        with open(chart_file, "w") as f:
+            json.dump(topic_overview_plot_json, f)
 
     return {
         "topic_overview_plot_json": json.loads(topic_overview_plot_json),
@@ -345,7 +348,7 @@ def fetch_existing_data(user, model_name):
 # - topic: topic to train on (used when tuning for a specific topic)
 def train_updated_model(model_name, ratings, user, top_n=None, topic=None, debug=False):
     # Check if there is previously-labeled data; if so, combine it with this data
-    labeled_df = format_labeled_data(ratings) # Treat ratings as full batch of all ratings
+    labeled_df = format_labeled_data(ratings, worker_id=user) # Treat ratings as full batch of all ratings
     ratings_prev = None
 
     # Filter out rows with "unsure" (-1)
@@ -362,7 +365,7 @@ def train_updated_model(model_name, ratings, user, top_n=None, topic=None, debug
         label_file = get_label_file(user, model_name, n_label_files - 1) # Get last label file
         with open(label_file, "rb") as f:
             ratings_prev = pickle.load(f)
-        labeled_df_prev = format_labeled_data(ratings_prev)
+        labeled_df_prev = format_labeled_data(ratings_prev, worker_id=user)
         labeled_df_prev = labeled_df_prev[labeled_df_prev["rating"] != -1]
         ratings.update(ratings_prev) # append old ratings to ratings
         labeled_df = pd.concat([labeled_df_prev, labeled_df])
@@ -377,23 +380,26 @@ def train_updated_model(model_name, ratings, user, top_n=None, topic=None, debug
     cur_model, _, _, _ = train_user_model(ratings_df=labeled_df)
 
     # Compute performance metrics
-    mae, mse, rmse, avg_diff = users_perf(cur_model)
+    mae, mse, rmse, avg_diff = users_perf(cur_model, worker_id=user)
     # Save performance metrics
     perf_file = get_perf_file(user, model_name)
     with open(perf_file, "wb") as f:
         pickle.dump((mae, mse, rmse, avg_diff), f)
 
     # Pre-compute predictions for full dataset
-    cur_preds_df = get_preds_df(cur_model, ["A"], sys_eval_df=ratings_df_full)
+    cur_preds_df = get_preds_df(cur_model, [user], sys_eval_df=ratings_df_full)
     # Save pre-computed predictions
     preds_file = get_preds_file(user, model_name)
     with open(preds_file, "wb") as f:
         pickle.dump(cur_preds_df, f)
 
+    # Replace cached summary plot if it exists
+    show_overall_perf(cur_model=model_name, error_type="Both", cur_user=user, use_cache=False)
+
     ratings_prev = ratings
     return mae, mse, rmse, avg_diff, ratings_prev
 
-def format_labeled_data(ratings, worker_id="A"):
+def format_labeled_data(ratings, worker_id):
     all_rows = []
     for comment, rating in ratings.items():
         comment_id = comments_to_ids[comment]
@@ -403,7 +409,7 @@ def format_labeled_data(ratings, worker_id="A"):
     df = pd.DataFrame(all_rows, columns=["user_id", "item_id", "rating"])
     return df
 
-def users_perf(model, sys_eval_df=sys_eval_df, worker_id="A"):
+def users_perf(model, worker_id, sys_eval_df=sys_eval_df):
     # Load the full empty dataset
     sys_eval_comment_ids = sys_eval_df.item_id.unique().tolist()
     empty_ratings_rows = [[worker_id, c_id, 0] for c_id in sys_eval_comment_ids]
@@ -423,7 +429,7 @@ def users_perf(model, sys_eval_df=sys_eval_df, worker_id="A"):
     df.dropna(subset = ["pred"], inplace=True)
     df["rating"] = df.rating.astype("int32")
 
-    perf_metrics = get_overall_perf(df, "A") # mae, mse, rmse, avg_diff
+    perf_metrics = get_overall_perf(df, worker_id) # mae, mse, rmse, avg_diff
     return perf_metrics
 
 def get_overall_perf(preds_df, user_id):
@@ -565,24 +571,24 @@ def plot_train_perf_results(user, model_name, mae):
         width=500,
     )
 
-    PCT_50 = 0.591
-    PCT_75 = 0.662
-    PCT_90 = 0.869
+    # Manually set for now
+    mae_good = 1.0
+    mae_okay = 1.2
 
     plot_dim_width = 500
     domain_min = 0.0
     domain_max = 2.0
     bkgd = alt.Chart(pd.DataFrame({
-        "start": [PCT_90, PCT_75, domain_min],
-        "stop": [domain_max, PCT_90, PCT_75],
-        "bkgd": ["Needs improvement (< top 90%)", "Okay (top 90%)", "Good (top 75%)"],
+        "start": [mae_okay, mae_good, domain_min],
+        "stop": [domain_max, mae_okay, mae_good],
+        "bkgd": ["Needs improvement", "Okay", "Good"],
     })).mark_rect(opacity=0.2).encode(
-        y=alt.Y("start:Q", scale=alt.Scale(domain=[0, domain_max])),
-        y2=alt.Y2("stop:Q"),
+        y=alt.Y("start:Q", scale=alt.Scale(domain=[0, domain_max]), title=""),
+        y2=alt.Y2("stop:Q", title="Performance (MAE)"),
         x=alt.value(0),
         x2=alt.value(plot_dim_width),
         color=alt.Color("bkgd:O", scale=alt.Scale(
-            domain=["Needs improvement (< top 90%)", "Okay (top 90%)", "Good (top 75%)"],
+            domain=["Needs improvement", "Okay", "Good"],
             range=["red", "yellow", "green"]),
             title="How good is your MAE?"
         )
@@ -590,12 +596,12 @@ def plot_train_perf_results(user, model_name, mae):
 
     plot = (bkgd + chart).properties(width=plot_dim_width).resolve_scale(color='independent')
     mae_status = None
-    if mae < PCT_75:
-        mae_status = "Your MAE is in the <b>Good</b> range, which means that it's in the top 75% of scores compared to other users. Your model looks good to go."
-    elif mae < PCT_90:
-        mae_status = "Your MAE is in the <b>Okay</b> range, which means that it's in the top 90% of scores compared to other users. Your model can be used, but you can provide additional labels to improve it."
+    if mae < mae_good:
+        mae_status = "Your MAE is in the <b>Good</b> range. Your model looks ready to go."
+    elif mae < mae_okay:
+        mae_status = "Your MAE is in the <b>Okay</b> range. Your model can be used, but you can provide additional labels to improve it."
     else:
-        mae_status = "Your MAE is in the <b>Needs improvement</b> range, which means that it's in below the top 95% of scores compared to other users. Your model may need additional labels to improve."
+        mae_status = "Your MAE is in the <b>Needs improvement</b> range. Your model may need additional labels to improve."
     return plot, mae_status
 
 ########################################
@@ -724,7 +730,7 @@ def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, b
         df = df[df["topic_id"] < n_topics]
 
     df["vis_pred_bin"], out_bins = pd.cut(df["pred"], bins, labels=VIS_BINS_LABELS, retbins=True)
-    df = df[df["user_id"] == "A"].sort_values(by=["item_id"]).reset_index()
+    df = df[df["user_id"] == cur_user].sort_values(by=["item_id"]).reset_index()
     df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[sys_col].tolist()]
     df["threshold"] = [threshold for r in df[sys_col].tolist()]
     df["key"] = [get_key(sys, user, threshold) for sys, user in zip(df[sys_col].tolist(), df["pred"].tolist())]
@@ -824,21 +830,15 @@ def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, b
     )
 
     plot = (bkgd + annotation + chart + rule).properties(height=(plot_dim_height), width=plot_dim_width).resolve_scale(color='independent').to_json()
-
-    # Save to file
-    chart_file = get_chart_file(cur_user, cur_model)
-    with open(chart_file, "w") as f:
-        json.dump(plot, f)
-
     return plot
 
 # Plots cluster results histogram (each block is a comment), but *without* a model
 # as a point of reference (in contrast to plot_overall_vis_cluster)
-def plot_overall_vis_cluster_no_model(preds_df, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
+def plot_overall_vis_cluster_no_model(cur_user, preds_df, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
     df = preds_df.copy().reset_index()
 
     df["vis_pred_bin"], out_bins = pd.cut(df[sys_col], bins, labels=VIS_BINS_LABELS, retbins=True)
-    df = df[df["user_id"] == "A"].sort_values(by=[sys_col]).reset_index()
+    df = df[df["user_id"] == cur_user].sort_values(by=[sys_col]).reset_index()
     df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[sys_col].tolist()]
     df["key"] = [get_key_no_model(sys, threshold) for sys in df[sys_col].tolist()]
     df["category"] = df.apply(lambda row: get_category(row), axis=1)
@@ -930,11 +930,11 @@ def plot_overall_vis_cluster_no_model(preds_df, n_comments=None, bins=VIS_BINS,
     return final_plot, df
 
 # Plots cluster results histogram (each block is a comment) *with* a model as a point of reference
-def plot_overall_vis_cluster(preds_df, error_type, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
+def plot_overall_vis_cluster(cur_user, preds_df, error_type, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
     df = preds_df.copy().reset_index()
 
     df["vis_pred_bin"], out_bins = pd.cut(df["pred"], bins, labels=VIS_BINS_LABELS, retbins=True)
-    df = df[df["user_id"] == "A"].sort_values(by=[sys_col]).reset_index(drop=True)
+    df = df[df["user_id"] == cur_user].sort_values(by=[sys_col]).reset_index(drop=True)
     df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[sys_col].tolist()]
     df["key"] = [get_key(sys, user, threshold) for sys, user in zip(df[sys_col].tolist(), df["pred"].tolist())]
     df["category"] = df.apply(lambda row: get_category(row), axis=1)
indie_label_svelte/src/Auditing.svelte CHANGED
@@ -158,12 +158,15 @@
 
     function handleAuditButton() {
         model_chosen.update((value) => personalized_model);
-        promise = getAudit();
+        if (personalized_model == "" || personalized_model == undefined) {
+            return;
+        }
+        promise = getAudit(personalized_model);
     }
 
-    async function getAudit() {
+    async function getAudit(pers_model) {
         let req_params = {
-            pers_model: personalized_model,
+            pers_model: pers_model,
             perf_metric: "avg_diff",
             breakdown_sort: "difference",
             n_topics: 10,
@@ -179,18 +182,18 @@
     }
 
     function handleClusterButton() {
-        promise_cluster = getCluster();
+        promise_cluster = getCluster(personalized_model);
     }
 
-    async function getCluster() {
-        if (personalized_model == "" || personalized_model == undefined) {
+    async function getCluster(pers_model) {
+        if (pers_model == "" || pers_model == undefined) {
             return null;
         }
         let req_params = {
             cluster: topic,
             topic_df_ids: [],
             cur_user: cur_user,
-            pers_model: personalized_model,
+            pers_model: pers_model,
             example_sort: "descending", // TEMP
             comparison_group: "status_quo", // TEMP
             search_type: "cluster",
indie_label_svelte/src/CommentTable.svelte CHANGED
@@ -88,11 +88,14 @@
             user: cur_user,
         };
         let params = new URLSearchParams(req_params).toString();
-        const response = await fetch("./get_personalized_model?" + params);
-        const text = await response.text();
-        const data = JSON.parse(text);
-        to_label = data["ratings_prev"];
-        model_chosen.update((value) => model_name);
+        const data = await fetch("./get_personalized_model?" + params)
+            .then((r) => r.text())
+            .then(function (text) {
+                let data = JSON.parse(text);
+                to_label = data["ratings_prev"];
+                model_chosen.update((value) => model_name);
+                return data;
+            });
         return data;
     }
 </script>
indie_label_svelte/src/ModelPerf.svelte CHANGED
@@ -1,9 +1,7 @@
 <script lang="ts">
     import { VegaLite } from "svelte-vega";
     import type { View } from "svelte-vega";
-
     import LayoutGrid, { Cell } from "@smui/layout-grid";
-    import Card, { Content } from '@smui/card';
 
     export let data;
 
@@ -13,64 +11,25 @@
     ];
     let perf_plot_view: View;
 
-    // let perf_plot2_spec = data["perf_plot2_json"];
-    // let perf_plot2_data = perf_plot2_spec["datasets"][perf_plot2_spec["data"]["name"]];
-    // let perf_plot2_view: View;
 </script>
 
 <div>
     <h6>Your Model Performance</h6>
-    <LayoutGrid>
-        <Cell span={8}>
-            <div class="card-container">
-                <Card variant="outlined" padded>
-                    <p class="mdc-typography--button"><b>Interpreting your model performance</b></p>
-                    <ul>
-                        <li>
-                            The <b>Mean Absolute Error (MAE)</b> metric indicates the average absolute difference between your model's rating and your actual rating on a held-out set of comments.
-                        </li>
-                        <li>
-                            You want your model to have a <b>lower</b> MAE (indicating <b>less error</b>).
-                        </li>
-                        <li>
-                            <b>Your current MAE: {data["mae"]}</b>
-                            <ul>
-                                <li>{@html data["mae_status"]}</li>
-                                <!-- <li>
-                                    This is <b>better</b> (lower) than the average MAE for other users, so your model appears to <b>better capture</b> your views than the typical user model.
-                                </li> -->
-                            </ul>
-                        </li>
-                    </ul>
-                </Card>
-            </div>
-        </Cell>
-    </LayoutGrid>
+    <ul>
+        <li>
+            The <b>Mean Absolute Error (MAE)</b> metric indicates the average absolute difference <br>between your model's rating and your actual rating on a held-out set of comments.
+        </li>
+        <li>
+            You want your model to have a <b>lower</b> MAE (indicating <b>less error</b>).
+        </li>
+        <li>
+            <b>Your current MAE: {data["mae"]}</b>
+            <ul>
+                <li>{@html data["mae_status"]}</li>
+            </ul>
+        </li>
+    </ul>
     <div>
-        <!-- Overall -->
-        <!-- <table>
-            <tbody>
-                <tr>
-                    <td>
-                        <span class="bold">Mean Absolute Error (MAE)</span><br>
-
-                    </td>
-                    <td>
-                        <span class="bold-large">{data["mae"]}</span>
-                    </td>
-                </tr>
-                <tr>
-                    <td>
-                        <span class="bold">Average rating difference</span><br>
-                        This metric indicates the average difference between your model's rating and your actual rating on a held-out set of comments.
-                    </td>
-                    <td>
-                        <span class="bold-large">{data["avg_diff"]}</span>
-                    </td>
-                </tr>
-            </tbody>
-        </table> -->
-
         <!-- Performance visualization -->
         <div>
             <VegaLite {perf_plot_data} spec={perf_plot_spec} bind:view={perf_plot_view}/>
server.py CHANGED
@@ -166,10 +166,10 @@ def get_cluster_results(debug=DEBUG):
     # Prepare overview plot for the cluster
     if use_model:
         # Display results with the model as a reference point
-        cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster(topic_df, error_type=error_type, n_comments=500)
+        cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster(cur_user, topic_df, error_type=error_type, n_comments=500)
     else:
         # Display results without a model
-        cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster_no_model(topic_df, n_comments=500)
+        cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster_no_model(cur_user, topic_df, n_comments=500)
 
     cluster_comments = utils.get_cluster_comments(sampled_df,error_type=error_type, use_model=use_model) # New version of cluster comment table
 
@@ -428,7 +428,7 @@ def get_reports():
     else:
         # Load from pickle file
         with open(reports_file, "rb") as f:
-            reports = pickle.load(f)
+            reports = json.load(f)
 
     results = {
         "reports": reports,
@@ -538,7 +538,7 @@ def get_personal_scaffold(cur_user, model, topic_vis_method, n_topics=200, n=5):
         preds_df = pickle.load(f)
     system_preds_df = utils.get_system_preds_df()
     preds_df_mod = preds_df.merge(system_preds_df, on="item_id", how="left", suffixes=('', '_sys'))
-    preds_df_mod = preds_df_mod[preds_df_mod["user_id"] == "A"].sort_values(by=["item_id"]).reset_index()
+    preds_df_mod = preds_df_mod[preds_df_mod["user_id"] == cur_user].sort_values(by=["item_id"]).reset_index()
    preds_df_mod = preds_df_mod[preds_df_mod["topic_id"] < n_topics]
 
     if topic_vis_method == "median":
@@ -643,8 +643,8 @@ def save_reports():
 
     # Save reports for current user to file
     reports_file = utils.get_reports_file(cur_user, model)
-    with open(reports_file, "wb") as f:
-        pickle.dump(reports, f)
+    with open(reports_file, "w", encoding ='utf8') as f:
+        json.dump(reports, f)
 
     results = {
         "status": "success",