Michelle Lam committed · Commit c55019b · 1 parent: 4adf2d3

Refreshes overall plot when model is edited. Uses user ID in preds_df. Exports reports and chart to json. Streamlines performance plot and description.

Files changed:
- audit_utils.py +36 -36
- indie_label_svelte/src/Auditing.svelte +10 -7
- indie_label_svelte/src/CommentTable.svelte +8 -5
- indie_label_svelte/src/ModelPerf.svelte +14 -55
- server.py +6 -6
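The largest change below is in audit_utils.py: the overall topic plot is now cached as JSON in each user/model output directory, and retraining regenerates it rather than reusing a stale chart. The following is a minimal, self-contained sketch of that read-through-cache pattern; the helper names chart_path, render_plot, and load_or_build_chart (and the toy chart payload) are illustrative stand-ins, not the repo's actual functions.

import json
import os

def chart_path(user, model):
    # Mirrors the idea of get_chart_file(): one cached chart per user/model pair.
    return os.path.join("data", "output", user, model, "chart_overall_vis.json")

def render_plot(user, model):
    # Placeholder: the real code builds an Altair chart and serializes it with to_json().
    return {"user": user, "model": model}

def load_or_build_chart(user, model, use_cache=True):
    # Read-through cache: reuse the saved chart JSON unless asked to rebuild.
    path = chart_path(user, model)
    if use_cache and os.path.isfile(path):
        with open(path, "r") as f:
            return json.load(f)
    spec = render_plot(user, model)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(spec, f)
    return spec

# After retraining, the cache is bypassed so the plot reflects the edited model, e.g.:
# load_or_build_chart("user123", "model_a", use_cache=False)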
audit_utils.py
CHANGED
@@ -139,7 +139,7 @@ def setup_user_model_dirs(cur_user, cur_model):
 # Charts
 def get_chart_file(cur_user, cur_model):
     chart_dir = f"./data/output/{cur_user}/{cur_model}"
-    return os.path.join(chart_dir, f"
+    return os.path.join(chart_dir, f"chart_overall_vis.json")
 
 # Labels
 def get_label_dir(cur_user, cur_model):
@@ -174,7 +174,7 @@ def get_preds_file(cur_user, cur_model):
 
 # Reports
 def get_reports_file(cur_user, cur_model):
-    return f"./data/output/{cur_user}/{cur_model}/reports.
+    return f"./data/output/{cur_user}/{cur_model}/reports.json"
 
 ########################################
 # General utils
@@ -236,14 +236,14 @@ def plot_metric_histogram(metric, user_metric, other_metric_vals, n_bins=10):
     return (bar + rule).interactive()
 
 # Generates the summary plot across all topics for the user
-def show_overall_perf(cur_model, error_type, cur_user, threshold=TOXIC_THRESHOLD, topic_vis_method="median"):
+def show_overall_perf(cur_model, error_type, cur_user, threshold=TOXIC_THRESHOLD, topic_vis_method="median", use_cache=True):
     # Your perf (calculate using model and testset)
     preds_file = get_preds_file(cur_user, cur_model)
     with open(preds_file, "rb") as f:
         preds_df = pickle.load(f)
 
     chart_file = get_chart_file(cur_user, cur_model)
-    if os.path.isfile(chart_file):
+    if use_cache and os.path.isfile(chart_file):
         # Read from file if it exists
         with open(chart_file, "r") as f:
             topic_overview_plot_json = json.load(f)
@@ -254,6 +254,9 @@ def show_overall_perf(cur_model, error_type, cur_user, threshold=TOXIC_THRESHOLD
     elif topic_vis_method == "mean":
         preds_df_grp = preds_df.groupby(["topic", "user_id"]).mean()
         topic_overview_plot_json = plot_overall_vis(preds_df=preds_df_grp, n_topics=200, threshold=threshold, error_type=error_type, cur_user=cur_user, cur_model=cur_model)
+    # Save to file
+    with open(chart_file, "w") as f:
+        json.dump(topic_overview_plot_json, f)
 
     return {
         "topic_overview_plot_json": json.loads(topic_overview_plot_json),
@@ -345,7 +348,7 @@ def fetch_existing_data(user, model_name):
 # - topic: topic to train on (used when tuning for a specific topic)
 def train_updated_model(model_name, ratings, user, top_n=None, topic=None, debug=False):
     # Check if there is previously-labeled data; if so, combine it with this data
-    labeled_df = format_labeled_data(ratings) # Treat ratings as full batch of all ratings
+    labeled_df = format_labeled_data(ratings, worker_id=user) # Treat ratings as full batch of all ratings
     ratings_prev = None
 
     # Filter out rows with "unsure" (-1)
@@ -362,7 +365,7 @@ def train_updated_model(model_name, ratings, user, top_n=None, topic=None, debug
         label_file = get_label_file(user, model_name, n_label_files - 1) # Get last label file
         with open(label_file, "rb") as f:
             ratings_prev = pickle.load(f)
-        labeled_df_prev = format_labeled_data(ratings_prev)
+        labeled_df_prev = format_labeled_data(ratings_prev, worker_id=user)
         labeled_df_prev = labeled_df_prev[labeled_df_prev["rating"] != -1]
         ratings.update(ratings_prev) # append old ratings to ratings
         labeled_df = pd.concat([labeled_df_prev, labeled_df])
@@ -377,23 +380,26 @@ def train_updated_model(model_name, ratings, user, top_n=None, topic=None, debug
     cur_model, _, _, _ = train_user_model(ratings_df=labeled_df)
 
     # Compute performance metrics
-    mae, mse, rmse, avg_diff = users_perf(cur_model)
+    mae, mse, rmse, avg_diff = users_perf(cur_model, worker_id=user)
     # Save performance metrics
     perf_file = get_perf_file(user, model_name)
     with open(perf_file, "wb") as f:
         pickle.dump((mae, mse, rmse, avg_diff), f)
 
     # Pre-compute predictions for full dataset
-    cur_preds_df = get_preds_df(cur_model, [
+    cur_preds_df = get_preds_df(cur_model, [user], sys_eval_df=ratings_df_full)
     # Save pre-computed predictions
     preds_file = get_preds_file(user, model_name)
     with open(preds_file, "wb") as f:
         pickle.dump(cur_preds_df, f)
 
+    # Replace cached summary plot if it exists
+    show_overall_perf(cur_model=model_name, error_type="Both", cur_user=user, use_cache=False)
+
     ratings_prev = ratings
     return mae, mse, rmse, avg_diff, ratings_prev
 
-def format_labeled_data(ratings, worker_id="A"):
+def format_labeled_data(ratings, worker_id):
     all_rows = []
     for comment, rating in ratings.items():
         comment_id = comments_to_ids[comment]
@@ -403,7 +409,7 @@ def format_labeled_data(ratings, worker_id="A"):
     df = pd.DataFrame(all_rows, columns=["user_id", "item_id", "rating"])
     return df
 
-def users_perf(model, sys_eval_df=sys_eval_df, worker_id="A"):
+def users_perf(model, worker_id, sys_eval_df=sys_eval_df):
     # Load the full empty dataset
     sys_eval_comment_ids = sys_eval_df.item_id.unique().tolist()
     empty_ratings_rows = [[worker_id, c_id, 0] for c_id in sys_eval_comment_ids]
@@ -423,7 +429,7 @@ def users_perf(model, sys_eval_df=sys_eval_df, worker_id="A"):
     df.dropna(subset = ["pred"], inplace=True)
     df["rating"] = df.rating.astype("int32")
 
-    perf_metrics = get_overall_perf(df,
+    perf_metrics = get_overall_perf(df, worker_id) # mae, mse, rmse, avg_diff
     return perf_metrics
 
 def get_overall_perf(preds_df, user_id):
@@ -565,24 +571,24 @@ def plot_train_perf_results(user, model_name, mae):
         width=500,
     )
 
-
-
-
+    # Manually set for now
+    mae_good = 1.0
+    mae_okay = 1.2
 
     plot_dim_width = 500
     domain_min = 0.0
     domain_max = 2.0
     bkgd = alt.Chart(pd.DataFrame({
-        "start": [
-        "stop": [domain_max,
-        "bkgd": ["Needs improvement
+        "start": [mae_okay, mae_good, domain_min],
+        "stop": [domain_max, mae_okay, mae_good],
+        "bkgd": ["Needs improvement", "Okay", "Good"],
     })).mark_rect(opacity=0.2).encode(
-        y=alt.Y("start:Q", scale=alt.Scale(domain=[0, domain_max])),
-        y2=alt.Y2("stop:Q"),
+        y=alt.Y("start:Q", scale=alt.Scale(domain=[0, domain_max]), title=""),
+        y2=alt.Y2("stop:Q", title="Performance (MAE)"),
         x=alt.value(0),
         x2=alt.value(plot_dim_width),
         color=alt.Color("bkgd:O", scale=alt.Scale(
-            domain=["Needs improvement
+            domain=["Needs improvement", "Okay", "Good"],
             range=["red", "yellow", "green"]),
             title="How good is your MAE?"
         )
@@ -590,12 +596,12 @@ def plot_train_perf_results(user, model_name, mae):
 
     plot = (bkgd + chart).properties(width=plot_dim_width).resolve_scale(color='independent')
     mae_status = None
-    if mae <
-        mae_status = "Your MAE is in the <b>Good</b> range
-    elif mae <
-        mae_status = "Your MAE is in the <b>Okay</b> range
+    if mae < mae_good:
+        mae_status = "Your MAE is in the <b>Good</b> range. Your model looks ready to go."
+    elif mae < mae_okay:
+        mae_status = "Your MAE is in the <b>Okay</b> range. Your model can be used, but you can provide additional labels to improve it."
     else:
-        mae_status = "Your MAE is in the <b>Needs improvement</b> range
+        mae_status = "Your MAE is in the <b>Needs improvement</b> range. Your model may need additional labels to improve."
     return plot, mae_status
 
 ########################################
@@ -724,7 +730,7 @@ def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, b
     df = df[df["topic_id"] < n_topics]
 
     df["vis_pred_bin"], out_bins = pd.cut(df["pred"], bins, labels=VIS_BINS_LABELS, retbins=True)
-    df = df[df["user_id"] ==
+    df = df[df["user_id"] == cur_user].sort_values(by=["item_id"]).reset_index()
    df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[sys_col].tolist()]
    df["threshold"] = [threshold for r in df[sys_col].tolist()]
    df["key"] = [get_key(sys, user, threshold) for sys, user in zip(df[sys_col].tolist(), df["pred"].tolist())]
@@ -824,21 +830,15 @@ def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, b
     )
 
     plot = (bkgd + annotation + chart + rule).properties(height=(plot_dim_height), width=plot_dim_width).resolve_scale(color='independent').to_json()
-
-    # Save to file
-    chart_file = get_chart_file(cur_user, cur_model)
-    with open(chart_file, "w") as f:
-        json.dump(plot, f)
-
     return plot
 
 # Plots cluster results histogram (each block is a comment), but *without* a model
 # as a point of reference (in contrast to plot_overall_vis_cluster)
-def plot_overall_vis_cluster_no_model(preds_df, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
+def plot_overall_vis_cluster_no_model(cur_user, preds_df, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
     df = preds_df.copy().reset_index()
 
     df["vis_pred_bin"], out_bins = pd.cut(df[sys_col], bins, labels=VIS_BINS_LABELS, retbins=True)
-    df = df[df["user_id"] ==
+    df = df[df["user_id"] == cur_user].sort_values(by=[sys_col]).reset_index()
     df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[sys_col].tolist()]
     df["key"] = [get_key_no_model(sys, threshold) for sys in df[sys_col].tolist()]
     df["category"] = df.apply(lambda row: get_category(row), axis=1)
@@ -930,11 +930,11 @@ def plot_overall_vis_cluster_no_model(preds_df, n_comments=None, bins=VIS_BINS,
     return final_plot, df
 
 # Plots cluster results histogram (each block is a comment) *with* a model as a point of reference
-def plot_overall_vis_cluster(preds_df, error_type, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
+def plot_overall_vis_cluster(cur_user, preds_df, error_type, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
     df = preds_df.copy().reset_index()
 
     df["vis_pred_bin"], out_bins = pd.cut(df["pred"], bins, labels=VIS_BINS_LABELS, retbins=True)
-    df = df[df["user_id"] ==
+    df = df[df["user_id"] == cur_user].sort_values(by=[sys_col]).reset_index(drop=True)
     df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[sys_col].tolist()]
     df["key"] = [get_key(sys, user, threshold) for sys, user in zip(df[sys_col].tolist(), df["pred"].tolist())]
     df["category"] = df.apply(lambda row: get_category(row), axis=1)
indie_label_svelte/src/Auditing.svelte
CHANGED
@@ -158,12 +158,15 @@
 
     function handleAuditButton() {
         model_chosen.update((value) => personalized_model);
-
+        if (personalized_model == "" || personalized_model == undefined) {
+            return;
+        }
+        promise = getAudit(personalized_model);
     }
 
-    async function getAudit() {
+    async function getAudit(pers_model) {
         let req_params = {
-            pers_model:
+            pers_model: pers_model,
             perf_metric: "avg_diff",
             breakdown_sort: "difference",
             n_topics: 10,
@@ -179,18 +182,18 @@
     }
 
     function handleClusterButton() {
-        promise_cluster = getCluster();
+        promise_cluster = getCluster(personalized_model);
     }
 
-    async function getCluster() {
-        if (
+    async function getCluster(pers_model) {
+        if (pers_model == "" || pers_model == undefined) {
             return null;
         }
         let req_params = {
             cluster: topic,
             topic_df_ids: [],
             cur_user: cur_user,
-            pers_model:
+            pers_model: pers_model,
             example_sort: "descending", // TEMP
             comparison_group: "status_quo", // TEMP
             search_type: "cluster",
indie_label_svelte/src/CommentTable.svelte
CHANGED
@@ -88,11 +88,14 @@
             user: cur_user,
         };
         let params = new URLSearchParams(req_params).toString();
-        const
-
-
-
-
+        const data = await fetch("./get_personalized_model?" + params)
+            .then((r) => r.text())
+            .then(function (text) {
+                let data = JSON.parse(text);
+                to_label = data["ratings_prev"];
+                model_chosen.update((value) => model_name);
+                return data;
+            });
         return data;
     }
 </script>
indie_label_svelte/src/ModelPerf.svelte
CHANGED
@@ -1,9 +1,7 @@
 <script lang="ts">
     import { VegaLite } from "svelte-vega";
     import type { View } from "svelte-vega";
-
     import LayoutGrid, { Cell } from "@smui/layout-grid";
-    import Card, { Content } from '@smui/card';
 
     export let data;
 
@@ -13,64 +11,25 @@
 ];
 let perf_plot_view: View;
 
-    // let perf_plot2_spec = data["perf_plot2_json"];
-    // let perf_plot2_data = perf_plot2_spec["datasets"][perf_plot2_spec["data"]["name"]];
-    // let perf_plot2_view: View;
 </script>
 
 <div>
     <h6>Your Model Performance</h6>
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    <ul>
-        <li>{@html data["mae_status"]}</li>
-        <!-- <li>
-            This is <b>better</b> (lower) than the average MAE for other users, so your model appears to <b>better capture</b> your views than the typical user model.
-        </li> -->
-    </ul>
-    </li>
-    </ul>
-    </Card>
-    </div>
-    </Cell>
-</LayoutGrid>
+    <ul>
+        <li>
+            The <b>Mean Absolute Error (MAE)</b> metric indicates the average absolute difference <br>between your model's rating and your actual rating on a held-out set of comments.
+        </li>
+        <li>
+            You want your model to have a <b>lower</b> MAE (indicating <b>less error</b>).
+        </li>
+        <li>
+            <b>Your current MAE: {data["mae"]}</b>
+            <ul>
+                <li>{@html data["mae_status"]}</li>
+            </ul>
+        </li>
+    </ul>
 <div>
-    <!-- Overall -->
-    <!-- <table>
-        <tbody>
-            <tr>
-                <td>
-                    <span class="bold">Mean Absolute Error (MAE)</span><br>
-
-                </td>
-                <td>
-                    <span class="bold-large">{data["mae"]}</span>
-                </td>
-            </tr>
-            <tr>
-                <td>
-                    <span class="bold">Average rating difference</span><br>
-                    This metric indicates the average difference between your model's rating and your actual rating on a held-out set of comments.
-                </td>
-                <td>
-                    <span class="bold-large">{data["avg_diff"]}</span>
-                </td>
-            </tr>
-        </tbody>
-    </table> -->
-
     <!-- Performance visualization -->
     <div>
         <VegaLite {perf_plot_data} spec={perf_plot_spec} bind:view={perf_plot_view}/>
server.py
CHANGED
@@ -166,10 +166,10 @@ def get_cluster_results(debug=DEBUG):
     # Prepare overview plot for the cluster
     if use_model:
         # Display results with the model as a reference point
-        cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster(topic_df, error_type=error_type, n_comments=500)
+        cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster(cur_user, topic_df, error_type=error_type, n_comments=500)
     else:
         # Display results without a model
-        cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster_no_model(topic_df, n_comments=500)
+        cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster_no_model(cur_user, topic_df, n_comments=500)
 
     cluster_comments = utils.get_cluster_comments(sampled_df,error_type=error_type, use_model=use_model) # New version of cluster comment table
 
@@ -428,7 +428,7 @@ def get_reports():
     else:
         # Load from pickle file
        with open(reports_file, "rb") as f:
-            reports =
+            reports = json.load(f)
 
     results = {
         "reports": reports,
@@ -538,7 +538,7 @@ def get_personal_scaffold(cur_user, model, topic_vis_method, n_topics=200, n=5):
         preds_df = pickle.load(f)
     system_preds_df = utils.get_system_preds_df()
     preds_df_mod = preds_df.merge(system_preds_df, on="item_id", how="left", suffixes=('', '_sys'))
-    preds_df_mod = preds_df_mod[preds_df_mod["user_id"] ==
+    preds_df_mod = preds_df_mod[preds_df_mod["user_id"] == cur_user].sort_values(by=["item_id"]).reset_index()
     preds_df_mod = preds_df_mod[preds_df_mod["topic_id"] < n_topics]
 
     if topic_vis_method == "median":
@@ -643,8 +643,8 @@ def save_reports():
 
     # Save reports for current user to file
     reports_file = utils.get_reports_file(cur_user, model)
-    with open(reports_file, "
-
+    with open(reports_file, "w", encoding ='utf8') as f:
+        json.dump(reports, f)
 
     results = {
         "status": "success",