Spaces:
Runtime error
Runtime error
Michelle Lam
commited on
Commit
·
b04690b
1
Parent(s):
70ab0be
Adapts labeling and auditing for single-session flow. Removes unused functionality throughout.
Browse files- Removes full model caching.
- Cleans up comments_grouped_full_topic_cat to system_preds_df; pre-processes data and renames+refactors merging operations to avoid confusion.
- Removes unused functionality (personal clustering, comparing against others' performance, nearest neighbor search).
- Moves constant data to data/input/ directory.
Adds automatically generated usernames. Removes username selection and shared user store. Removes Results and Study Links views. Removes AppOld component.
- .gitignore +5 -0
- audit_utils.py +129 -652
- indie_label_svelte/src/Auditing.svelte +9 -17
- indie_label_svelte/src/CommentTable.svelte +8 -2
- indie_label_svelte/src/Hunch.svelte +0 -26
- indie_label_svelte/src/HypothesisPanel.svelte +7 -1
- indie_label_svelte/src/IterativeClustering.svelte +0 -164
- indie_label_svelte/src/KeywordSearch.svelte +0 -3
- indie_label_svelte/src/Labeling.svelte +2 -1
- server.py +56 -137
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.DS_Store
|
3 |
+
data/
|
4 |
+
data.zip
|
5 |
+
test_nbs/
|
audit_utils.py
CHANGED
@@ -40,66 +40,48 @@ module_dir = "./"
|
|
40 |
perf_dir = f"data/perf/"
|
41 |
|
42 |
# # TEMP reset
|
43 |
-
# with open(os.path.join(module_dir, "./data/all_model_names.pkl"), "wb") as f:
|
44 |
-
# all_model_names = []
|
45 |
-
# pickle.dump(all_model_names, f)
|
46 |
# with open(f"./data/users_to_models.pkl", "wb") as f:
|
47 |
# users_to_models = {}
|
48 |
# pickle.dump(users_to_models, f)
|
49 |
|
50 |
-
|
51 |
-
with open(os.path.join(module_dir, "data/ids_to_comments.pkl"), "rb") as f:
|
52 |
ids_to_comments = pickle.load(f)
|
53 |
-
with open(os.path.join(module_dir, "data/comments_to_ids.pkl"), "rb") as f:
|
54 |
comments_to_ids = pickle.load(f)
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
sys_eval_df = pd.read_pickle(os.path.join(module_dir, "data/split_data/sys_eval_df.pkl"))
|
59 |
-
train_df = pd.read_pickle(os.path.join(module_dir, "data/split_data/train_df.pkl"))
|
60 |
train_df_ids = train_df["item_id"].unique().tolist()
|
61 |
-
model_eval_df = pd.read_pickle(os.path.join(module_dir, "data/split_data/model_eval_df.pkl"))
|
62 |
-
ratings_df_full = pd.read_pickle(os.path.join(module_dir, "data/ratings_df_full.pkl"))
|
63 |
-
|
64 |
-
worker_info_df = pd.read_pickle("./data/worker_info_df.pkl")
|
65 |
|
66 |
with open(f"./data/users_to_models.pkl", "rb") as f:
|
67 |
users_to_models = pickle.load(f)
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
with open("data/perf_1000_tox_cat.pkl", "rb") as f:
|
72 |
-
perf_1000_tox_cat = pickle.load(f)
|
73 |
-
with open("data/perf_1000_tox_severity.pkl", "rb") as f:
|
74 |
-
perf_1000_tox_severity = pickle.load(f)
|
75 |
-
with open("data/user_perf_metrics.pkl", "rb") as f:
|
76 |
-
user_perf_metrics = pickle.load(f)
|
77 |
-
|
78 |
-
topic_ids = comments_grouped_full_topic_cat.topic_id
|
79 |
-
topics = comments_grouped_full_topic_cat.topic
|
80 |
topic_ids_to_topics = {topic_ids[i]: topics[i] for i in range(len(topic_ids))}
|
81 |
topics_to_topic_ids = {topics[i]: topic_ids[i] for i in range(len(topic_ids))}
|
82 |
-
unique_topics_ids = sorted(
|
83 |
unique_topics = [topic_ids_to_topics[topic_id] for topic_id in range(len(topic_ids_to_topics) - 1)]
|
84 |
|
85 |
def get_toxic_threshold():
|
86 |
return TOXIC_THRESHOLD
|
87 |
|
88 |
-
def
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
user_models.sort()
|
96 |
-
return user_models
|
97 |
|
98 |
def get_unique_topics():
|
99 |
return unique_topics
|
100 |
|
101 |
def get_large_clusters(min_n):
|
102 |
-
counts_df =
|
103 |
counts_df = counts_df[counts_df["counts"] >= min_n]
|
104 |
return [topic_ids_to_topics[t_id] for t_id in sorted(counts_df["topic_id"].tolist()[1:])]
|
105 |
|
@@ -137,32 +119,8 @@ readable_to_internal = {
|
|
137 |
}
|
138 |
internal_to_readable = {v: k for k, v in readable_to_internal.items()}
|
139 |
|
140 |
-
|
141 |
-
|
142 |
-
model = SentenceTransformer(model_name)
|
143 |
-
with open("./data/comments.pkl", "rb") as f:
|
144 |
-
comments = pickle.load(f)
|
145 |
-
embeddings = torch.load("./data/embeddings/21_10_embeddings.pt")
|
146 |
-
|
147 |
-
# Perspective API recalibration
|
148 |
-
def recalib_v1(s):
|
149 |
-
# convert Perspective score to 0-4 toxicity score
|
150 |
-
# map 0 persp to 0 (not at all toxic); 0.5 persp to 1 (slightly toxic), 1.0 persp to 4 (extremely toxic)
|
151 |
-
if s < 0.5:
|
152 |
-
return (s * 2.)
|
153 |
-
else:
|
154 |
-
return ((s - 0.5) * 6.) + 1
|
155 |
-
|
156 |
-
def recalib_v2(s):
|
157 |
-
# convert Perspective score to 0-4 toxicity score
|
158 |
-
# just 4x the perspective score
|
159 |
-
return (s * 4.)
|
160 |
-
|
161 |
-
comments_grouped_full_topic_cat["rating_avg_orig"] = comments_grouped_full_topic_cat["rating"]
|
162 |
-
comments_grouped_full_topic_cat["rating"] = [recalib_v2(score) for score in comments_grouped_full_topic_cat["persp_score"].tolist()]
|
163 |
-
|
164 |
-
def get_comments_grouped_full_topic_cat():
|
165 |
-
return comments_grouped_full_topic_cat
|
166 |
|
167 |
########################################
|
168 |
# General utils
|
@@ -192,22 +150,6 @@ def my_bootstrap(vals, n_boot, alpha):
|
|
192 |
|
193 |
########################################
|
194 |
# GET_AUDIT utils
|
195 |
-
def other_users_perf(perf_metrics, metric, user_metric, alpha=0.95, n_boot=501):
|
196 |
-
ind = get_metric_ind(metric)
|
197 |
-
|
198 |
-
metric_vals = [metric_vals[ind] for metric_vals in perf_metrics.values()]
|
199 |
-
metric_avg = np.median(metric_vals)
|
200 |
-
|
201 |
-
# Future: use provided sample to perform bootstrap sampling
|
202 |
-
ci_1 = mne.stats.bootstrap_confidence_interval(np.array(metric_vals), ci=alpha, n_bootstraps=n_boot, stat_fun="median")
|
203 |
-
|
204 |
-
bs_samples, ci = my_bootstrap(metric_vals, n_boot, alpha)
|
205 |
-
|
206 |
-
# Get user's percentile
|
207 |
-
percentile = stats.percentileofscore(bs_samples, user_metric)
|
208 |
-
|
209 |
-
return metric_avg, ci, percentile, metric_vals
|
210 |
-
|
211 |
def plot_metric_histogram(metric, user_metric, other_metric_vals, n_bins=10):
|
212 |
hist, bin_edges = np.histogram(other_metric_vals, bins=n_bins, density=False)
|
213 |
data = pd.DataFrame({
|
@@ -239,395 +181,34 @@ def plot_metric_histogram(metric, user_metric, other_metric_vals, n_bins=10):
|
|
239 |
|
240 |
return (bar + rule).interactive()
|
241 |
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
other_ci_low = []
|
248 |
-
other_ci_high = []
|
249 |
-
for severity_i in range(len(bin_labels)):
|
250 |
-
metric_others = [metrics[get_metric_ind(perf_metric)] for metrics in perf_1000_tox_severity[severity_i].values() if metrics[get_metric_ind(perf_metric)]]
|
251 |
-
ci_low, ci_high = mne.stats.bootstrap_confidence_interval(np.array(metric_others), ci=ci, n_bootstraps=n_boot, stat_fun='median')
|
252 |
-
metric_other = np.median(metric_others)
|
253 |
-
|
254 |
-
cur_user_df = user_df[user_df["prediction_bin"] == severity_i]
|
255 |
-
y_true_user = cur_user_df.pred.to_numpy() # user's label
|
256 |
-
y_pred = cur_user_df.rating_avg.to_numpy() # system's label (avg)
|
257 |
-
|
258 |
-
if len(y_true_user) > 0:
|
259 |
-
used_bins.append(bin_labels[severity_i])
|
260 |
-
metric_user = calc_metric_user(y_true_user, y_pred, perf_metric)
|
261 |
-
y_user.append(metric_user)
|
262 |
-
y_other.append(metric_other)
|
263 |
-
other_ci_low.append(ci_low)
|
264 |
-
other_ci_high.append(ci_high)
|
265 |
-
|
266 |
-
return y_user, y_other, used_bins, other_ci_low, other_ci_high
|
267 |
-
|
268 |
-
def get_topic_bins(perf_metric, user_df, other_dfs, n_topics, ci=0.95, n_boot=501):
|
269 |
-
# Note: not using other_dfs anymore
|
270 |
-
y_user = []
|
271 |
-
y_other = []
|
272 |
-
used_bins = []
|
273 |
-
other_ci_low = []
|
274 |
-
other_ci_high = []
|
275 |
-
selected_topics = unique_topics_ids[1:(n_topics + 1)]
|
276 |
-
|
277 |
-
for topic_id in selected_topics:
|
278 |
-
cur_topic = topic_ids_to_topics[topic_id]
|
279 |
-
metric_others = [metrics[get_metric_ind(perf_metric)] for metrics in perf_1000_topics[topic_id].values() if metrics[get_metric_ind(perf_metric)]]
|
280 |
-
ci_low, ci_high = mne.stats.bootstrap_confidence_interval(np.array(metric_others), ci=ci, n_bootstraps=n_boot, stat_fun='median')
|
281 |
-
metric_other = np.median(metric_others)
|
282 |
-
|
283 |
-
cur_user_df = user_df[user_df["topic"] == cur_topic]
|
284 |
-
y_true_user = cur_user_df.pred.to_numpy() # user's label
|
285 |
-
y_pred = cur_user_df.rating_avg.to_numpy() # system's label (avg)
|
286 |
-
|
287 |
-
if len(y_true_user) > 0:
|
288 |
-
used_bins.append(cur_topic)
|
289 |
-
metric_user = calc_metric_user(y_true_user, y_pred, perf_metric)
|
290 |
-
y_user.append(metric_user)
|
291 |
-
y_other.append(metric_other)
|
292 |
-
other_ci_low.append(ci_low)
|
293 |
-
other_ci_high.append(ci_high)
|
294 |
-
|
295 |
-
return y_user, y_other, used_bins, other_ci_low, other_ci_high
|
296 |
-
|
297 |
-
def calc_metric_user(y_true_user, y_pred, perf_metric):
|
298 |
-
if perf_metric == "MAE":
|
299 |
-
metric_user = mean_absolute_error(y_true_user, y_pred)
|
300 |
-
|
301 |
-
elif perf_metric == "MSE":
|
302 |
-
metric_user = mean_squared_error(y_true_user, y_pred)
|
303 |
-
|
304 |
-
elif perf_metric == "RMSE":
|
305 |
-
metric_user = mean_squared_error(y_true_user, y_pred, squared=False)
|
306 |
-
|
307 |
-
elif perf_metric == "avg_diff":
|
308 |
-
metric_user = np.mean(y_true_user - y_pred)
|
309 |
-
|
310 |
-
return metric_user
|
311 |
-
|
312 |
-
def get_toxicity_category_bins(perf_metric, user_df, other_dfs, threshold=0.5, ci=0.95, n_boot=501):
|
313 |
-
# Note: not using other_dfs anymore; threshold from pre-calculation is 0.5
|
314 |
-
cat_cols = ["is_profane_frac", "is_threat_frac", "is_identity_attack_frac", "is_insult_frac", "is_sexual_harassment_frac"]
|
315 |
-
cat_labels = ["Profanity", "Threats", "Identity Attacks", "Insults", "Sexual Harassment"]
|
316 |
-
y_user = []
|
317 |
-
y_other = []
|
318 |
-
used_bins = []
|
319 |
-
other_ci_low = []
|
320 |
-
other_ci_high = []
|
321 |
-
for i, cur_col_name in enumerate(cat_cols):
|
322 |
-
metric_others = [metrics[get_metric_ind(perf_metric)] for metrics in perf_1000_tox_cat[cur_col_name].values() if metrics[get_metric_ind(perf_metric)]]
|
323 |
-
ci_low, ci_high = mne.stats.bootstrap_confidence_interval(np.array(metric_others), ci=ci, n_bootstraps=n_boot, stat_fun='median')
|
324 |
-
metric_other = np.median(metric_others)
|
325 |
-
|
326 |
-
# Filter to rows where a comment received an average label >= the provided threshold for the category
|
327 |
-
cur_user_df = user_df[user_df[cur_col_name] >= threshold]
|
328 |
-
y_true_user = cur_user_df.pred.to_numpy() # user's label
|
329 |
-
y_pred = cur_user_df.rating_avg.to_numpy() # system's label (avg)
|
330 |
-
|
331 |
-
if len(y_true_user) > 0:
|
332 |
-
used_bins.append(cat_labels[i])
|
333 |
-
metric_user = calc_metric_user(y_true_user, y_pred, perf_metric)
|
334 |
-
y_user.append(metric_user)
|
335 |
-
y_other.append(metric_other)
|
336 |
-
other_ci_low.append(ci_low)
|
337 |
-
other_ci_high.append(ci_high)
|
338 |
-
|
339 |
-
return y_user, y_other, used_bins, other_ci_low, other_ci_high
|
340 |
-
|
341 |
-
def plot_class_cond_results(preds_df, breakdown_axis, perf_metric, other_ids, sort_bars, n_topics, worker_id="A"):
|
342 |
-
# Note: preds_df already has binned results
|
343 |
-
# Prepare dfs
|
344 |
-
user_df = preds_df[preds_df.user_id == worker_id].sort_values(by=["item_id"]).reset_index()
|
345 |
-
other_dfs = [preds_df[preds_df.user_id == other_id].sort_values(by=["item_id"]).reset_index() for other_id in other_ids]
|
346 |
-
|
347 |
-
if breakdown_axis == "toxicity_severity":
|
348 |
-
y_user, y_other, used_bins, other_ci_low, other_ci_high = get_toxicity_severity_bins(perf_metric, user_df, other_dfs)
|
349 |
-
elif breakdown_axis == "topic":
|
350 |
-
y_user, y_other, used_bins, other_ci_low, other_ci_high = get_topic_bins(perf_metric, user_df, other_dfs, n_topics)
|
351 |
-
elif breakdown_axis == "toxicity_category":
|
352 |
-
y_user, y_other, used_bins, other_ci_low, other_ci_high = get_toxicity_category_bins(perf_metric, user_df, other_dfs)
|
353 |
-
|
354 |
-
diffs = list(np.array(y_user) - np.array(y_other))
|
355 |
-
|
356 |
-
# Generate bar chart
|
357 |
-
data = pd.DataFrame({
|
358 |
-
"metric_val": y_user + y_other,
|
359 |
-
"Labeler": ["You" for _ in range(len(y_user))] + ["Other users" for _ in range(len(y_user))],
|
360 |
-
"used_bins": used_bins + used_bins,
|
361 |
-
"diffs": diffs + diffs,
|
362 |
-
"lower_cis": y_user + other_ci_low,
|
363 |
-
"upper_cis": y_user + other_ci_high,
|
364 |
-
})
|
365 |
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
y_axis = alt.Y("metric_val:Q", title=internal_to_readable[perf_metric])
|
373 |
-
if sort_bars:
|
374 |
-
col_content = alt.Column("used_bins:O", sort=alt.EncodingSortField(field="diffs", op="mean", order='descending'))
|
375 |
-
else:
|
376 |
-
col_content = alt.Column("used_bins:O")
|
377 |
-
|
378 |
-
if n_topics is not None and n_topics > 10:
|
379 |
-
# Change to horizontal bar chart
|
380 |
-
bar = base.mark_bar(lineBreak="_").encode(
|
381 |
-
y=x_axis,
|
382 |
-
x=y_axis,
|
383 |
-
color=alt.Color("Labeler:O", scale=alt.Scale(domain=color_domain, range=color_range)),
|
384 |
-
tooltip=[
|
385 |
-
alt.Tooltip('Labeler:O', title='Labeler'),
|
386 |
-
alt.Tooltip('metric_val:Q', title=perf_metric, format=".3f"),
|
387 |
-
]
|
388 |
-
)
|
389 |
-
error_bars = base.mark_errorbar().encode(
|
390 |
-
y=x_axis,
|
391 |
-
x = alt.X("lower_cis:Q", title=internal_to_readable[perf_metric]),
|
392 |
-
x2 = alt.X2("upper_cis:Q", title=None),
|
393 |
-
tooltip=[
|
394 |
-
alt.Tooltip('lower_cis:Q', title='Lower CI', format=".3f"),
|
395 |
-
alt.Tooltip('upper_cis:Q', title='Upper CI', format=".3f"),
|
396 |
-
]
|
397 |
-
)
|
398 |
-
combined = alt.layer(
|
399 |
-
bar, error_bars, data=data
|
400 |
-
).facet(
|
401 |
-
row=col_content
|
402 |
-
).properties(
|
403 |
-
title=chart_title,
|
404 |
-
).interactive()
|
405 |
else:
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
alt.Tooltip('Labeler:O', title='Labeler'),
|
412 |
-
alt.Tooltip('metric_val:Q', title=perf_metric, format=".3f"),
|
413 |
-
]
|
414 |
-
)
|
415 |
-
error_bars = base.mark_errorbar().encode(
|
416 |
-
x=x_axis,
|
417 |
-
y = alt.Y("lower_cis:Q", title=internal_to_readable[perf_metric]),
|
418 |
-
y2 = alt.Y2("upper_cis:Q", title=None),
|
419 |
-
tooltip=[
|
420 |
-
alt.Tooltip('lower_cis:Q', title='Lower CI', format=".3f"),
|
421 |
-
alt.Tooltip('upper_cis:Q', title='Upper CI', format=".3f"),
|
422 |
-
]
|
423 |
-
)
|
424 |
-
combined = alt.layer(
|
425 |
-
bar, error_bars, data=data
|
426 |
-
).facet(
|
427 |
-
column=col_content
|
428 |
-
).properties(
|
429 |
-
title=chart_title,
|
430 |
-
).interactive()
|
431 |
-
|
432 |
-
return combined
|
433 |
-
|
434 |
-
# Generates the summary plot across all topics for the user
|
435 |
-
def show_overall_perf(variant, error_type, cur_user, threshold=TOXIC_THRESHOLD, breakdown_axis=None, topic_vis_method="median"):
|
436 |
-
# Your perf (calculate using model and testset)
|
437 |
-
breakdown_axis = readable_to_internal[breakdown_axis]
|
438 |
-
|
439 |
-
if breakdown_axis is not None:
|
440 |
-
with open(os.path.join(module_dir, f"data/preds_dfs/{variant}.pkl"), "rb") as f:
|
441 |
-
preds_df = pickle.load(f)
|
442 |
-
|
443 |
-
# Read from file
|
444 |
-
chart_dir = "./data/charts"
|
445 |
-
chart_file = os.path.join(chart_dir, f"{cur_user}_{variant}.pkl")
|
446 |
-
if os.path.isfile(chart_file):
|
447 |
-
with open(chart_file, "r") as f:
|
448 |
-
topic_overview_plot_json = json.load(f)
|
449 |
-
else:
|
450 |
-
preds_df_mod = preds_df.merge(comments_grouped_full_topic_cat, on="item_id", how="left", suffixes=('_', '_avg'))
|
451 |
-
if topic_vis_method == "median": # Default
|
452 |
-
preds_df_mod_grp = preds_df_mod.groupby(["topic_", "user_id"]).median()
|
453 |
-
elif topic_vis_method == "mean":
|
454 |
-
preds_df_mod_grp = preds_df_mod.groupby(["topic_", "user_id"]).mean()
|
455 |
-
topic_overview_plot_json = plot_overall_vis(preds_df=preds_df_mod_grp, n_topics=200, threshold=threshold, error_type=error_type, cur_user=cur_user, cur_model=variant)
|
456 |
|
457 |
return {
|
458 |
"topic_overview_plot_json": json.loads(topic_overview_plot_json),
|
459 |
}
|
460 |
|
461 |
-
########################################
|
462 |
-
# GET_CLUSTER_RESULTS utils
|
463 |
-
def get_overall_perf3(preds_df, perf_metric, other_ids, worker_id="A"):
|
464 |
-
# Prepare dataset to calculate performance
|
465 |
-
# Note: true is user and pred is system
|
466 |
-
y_true = preds_df[preds_df["user_id"] == worker_id].pred.to_numpy()
|
467 |
-
y_pred_user = preds_df[preds_df["user_id"] == worker_id].rating_avg.to_numpy()
|
468 |
-
|
469 |
-
y_true_others = y_pred_others = [preds_df[preds_df["user_id"] == other_id].pred.to_numpy() for other_id in other_ids]
|
470 |
-
y_pred_others = [preds_df[preds_df["user_id"] == other_id].rating_avg.to_numpy() for other_id in other_ids]
|
471 |
-
|
472 |
-
# Get performance for user's model and for other users
|
473 |
-
if perf_metric == "MAE":
|
474 |
-
user_perf = mean_absolute_error(y_true, y_pred_user)
|
475 |
-
other_perfs = [mean_absolute_error(y_true_others[i], y_pred_others[i]) for i in range(len(y_true_others))]
|
476 |
-
elif perf_metric == "MSE":
|
477 |
-
user_perf = mean_squared_error(y_true, y_pred_user)
|
478 |
-
other_perfs = [mean_squared_error(y_true_others[i], y_pred_others[i]) for i in range(len(y_true_others))]
|
479 |
-
elif perf_metric == "RMSE":
|
480 |
-
user_perf = mean_squared_error(y_true, y_pred_user, squared=False)
|
481 |
-
other_perfs = [mean_squared_error(y_true_others[i], y_pred_others[i], squared=False) for i in range(len(y_true_others))]
|
482 |
-
elif perf_metric == "avg_diff":
|
483 |
-
user_perf = np.mean(y_true - y_pred_user)
|
484 |
-
other_perfs = [np.mean(y_true_others[i] - y_pred_others[i]) for i in range(len(y_true_others))]
|
485 |
-
|
486 |
-
other_perf = np.mean(other_perfs) # average across all other users
|
487 |
-
return user_perf, other_perf
|
488 |
-
|
489 |
-
def style_color_difference(row):
|
490 |
-
full_opacity_diff = 3.
|
491 |
-
pred_user_col = "Your predicted rating"
|
492 |
-
pred_other_col = "Other users' predicted rating"
|
493 |
-
pred_system_col = "Status-quo system rating"
|
494 |
-
diff_user = row[pred_user_col] - row[pred_system_col]
|
495 |
-
diff_other = row[pred_other_col] - row[pred_system_col]
|
496 |
-
red = "234, 133, 125"
|
497 |
-
green = "142, 205, 162"
|
498 |
-
bkgd_user = green if diff_user < 0 else red # red if more toxic; green if less toxic
|
499 |
-
opac_user = min(abs(diff_user / full_opacity_diff), 1.)
|
500 |
-
bkgd_other = green if diff_other < 0 else red # red if more toxic; green if less toxic
|
501 |
-
opac_other = min(abs(diff_other / full_opacity_diff), 1.)
|
502 |
-
return ["", f"background-color: rgba({bkgd_user}, {opac_user});", f"background-color: rgba({bkgd_other}, {opac_other});", "", ""]
|
503 |
-
|
504 |
-
def display_examples_cluster(preds_df, other_ids, num_examples, sort_ascending, worker_id="A"):
|
505 |
-
user_df = preds_df[preds_df.user_id == worker_id].sort_values(by=["item_id"]).reset_index()
|
506 |
-
others_df = preds_df[preds_df.user_id == other_ids[0]]
|
507 |
-
for i in range(1, len(other_ids)):
|
508 |
-
others_df.append(preds_df[preds_df.user_id == other_ids[i]])
|
509 |
-
others_df.groupby(["item_id"]).mean()
|
510 |
-
others_df = others_df.sort_values(by=["item_id"]).reset_index()
|
511 |
-
|
512 |
-
df = pd.merge(user_df, others_df, on="item_id", how="left", suffixes=('_user', '_other'))
|
513 |
-
df["Comment"] = df["comment_user"]
|
514 |
-
df["Your predicted rating"] = df["pred_user"]
|
515 |
-
df["Other users' predicted rating"] = df["pred_other"]
|
516 |
-
df["Status-quo system rating"] = df["rating_avg_user"]
|
517 |
-
df["Status-quo system std dev"] = df["rating_stddev_user"]
|
518 |
-
df = df[["Comment", "Your predicted rating", "Other users' predicted rating", "Status-quo system rating", "Status-quo system std dev"]]
|
519 |
-
|
520 |
-
# Add styling
|
521 |
-
df = df.sort_values(by=['Status-quo system std dev'], ascending=sort_ascending)
|
522 |
-
n_to_sample = np.min([num_examples, len(df)])
|
523 |
-
df = df.sample(n=n_to_sample).reset_index(drop=True)
|
524 |
-
return df.style.apply(style_color_difference, axis=1).render()
|
525 |
-
|
526 |
-
def calc_odds_ratio(df, comparison_group, toxic_threshold=1.5, worker_id="A", debug=False, smoothing_factor=1):
|
527 |
-
if comparison_group == "status_quo":
|
528 |
-
other_pred_col = "rating_avg"
|
529 |
-
# Get unique comments, but fetch average labeler rating
|
530 |
-
num_toxic_other = len(df[(df.user_id == "A") & (df[other_pred_col] >= toxic_threshold)]) + smoothing_factor
|
531 |
-
num_nontoxic_other = len(df[(df.user_id == "A") & (df[other_pred_col] < toxic_threshold)]) + smoothing_factor
|
532 |
-
elif comparison_group == "other_users":
|
533 |
-
other_pred_col = "pred"
|
534 |
-
num_toxic_other = len(df[(df.user_id != "A") & (df[other_pred_col] >= toxic_threshold)]) + smoothing_factor
|
535 |
-
num_nontoxic_other = len(df[(df.user_id != "A") & (df[other_pred_col] < toxic_threshold)]) + smoothing_factor
|
536 |
-
|
537 |
-
num_toxic_user = len(df[(df.user_id == "A") & (df.pred >= toxic_threshold)]) + smoothing_factor
|
538 |
-
num_nontoxic_user = len(df[(df.user_id == "A") & (df.pred < toxic_threshold)]) + smoothing_factor
|
539 |
-
|
540 |
-
toxic_ratio = num_toxic_user / num_toxic_other
|
541 |
-
nontoxic_ratio = num_nontoxic_user / num_nontoxic_other
|
542 |
-
odds_ratio = toxic_ratio / nontoxic_ratio
|
543 |
-
|
544 |
-
if debug:
|
545 |
-
print(f"Odds ratio: {odds_ratio}")
|
546 |
-
print(f"num_toxic_user: {num_toxic_user}, num_nontoxic_user: {num_nontoxic_user}")
|
547 |
-
print(f"num_toxic_other: {num_toxic_other}, num_nontoxic_other: {num_nontoxic_other}")
|
548 |
-
|
549 |
-
contingency_table = [[num_toxic_user, num_nontoxic_user], [num_toxic_other, num_nontoxic_other]]
|
550 |
-
odds_ratio, p_val = stats.fisher_exact(contingency_table, alternative='two-sided')
|
551 |
-
if debug:
|
552 |
-
print(f"Odds ratio: {odds_ratio}, p={p_val}")
|
553 |
-
|
554 |
-
return odds_ratio
|
555 |
-
|
556 |
-
# Neighbor search
|
557 |
-
def get_match(comment_inds, K=20, threshold=None, debug=False):
|
558 |
-
match_ids = []
|
559 |
-
rows = []
|
560 |
-
for i in comment_inds:
|
561 |
-
if debug:
|
562 |
-
print(f"\nComment: {comments[i]}")
|
563 |
-
query_embedding = model.encode(comments[i], convert_to_tensor=True)
|
564 |
-
hits = util.semantic_search(query_embedding, embeddings, score_function=util.cos_sim, top_k=K)
|
565 |
-
# print(hits[0])
|
566 |
-
for hit in hits[0]:
|
567 |
-
c_id = hit['corpus_id']
|
568 |
-
score = np.round(hit['score'], 3)
|
569 |
-
if threshold is None or score > threshold:
|
570 |
-
match_ids.append(c_id)
|
571 |
-
if debug:
|
572 |
-
print(f"\t(ID={c_id}, Score={score}): {comments[c_id]}")
|
573 |
-
rows.append([c_id, score, comments[c_id]])
|
574 |
-
|
575 |
-
df = pd.DataFrame(rows, columns=["id", "score", "comment"])
|
576 |
-
return match_ids
|
577 |
-
|
578 |
-
def display_examples_auto_cluster(preds_df, cluster, other_ids, perf_metric, sort_ascending=True, worker_id="A", num_examples=10):
|
579 |
-
# Overall performance
|
580 |
-
topic_df = preds_df
|
581 |
-
topic_df = topic_df[topic_df["topic"] == cluster]
|
582 |
-
user_perf, other_perf = get_overall_perf3(topic_df, perf_metric, other_ids)
|
583 |
-
|
584 |
-
user_direction = "LOWER" if user_perf < 0 else "HIGHER"
|
585 |
-
other_direction = "LOWER" if other_perf < 0 else "HIGHER"
|
586 |
-
print(f"Your ratings are on average {np.round(abs(user_perf), 3)} {user_direction} than the existing system for this cluster")
|
587 |
-
print(f"Others' ratings (based on {len(other_ids)} users) are on average {np.round(abs(other_perf), 3)} {other_direction} than the existing system for this cluster")
|
588 |
-
|
589 |
-
# Display example comments
|
590 |
-
df = display_examples_cluster(preds_df, other_ids, num_examples, sort_ascending)
|
591 |
-
return df
|
592 |
-
|
593 |
-
|
594 |
-
# function to get results for a new provided cluster
|
595 |
-
def display_examples_manual_cluster(preds_df, cluster_comments, other_ids, perf_metric, sort_ascending=True, worker_id="A"):
|
596 |
-
# Overall performance
|
597 |
-
cluster_df = preds_df[preds_df["comment"].isin(cluster_comments)]
|
598 |
-
user_perf, other_perf = get_overall_perf3(cluster_df, perf_metric, other_ids)
|
599 |
-
|
600 |
-
user_direction = "LOWER" if user_perf < 0 else "HIGHER"
|
601 |
-
other_direction = "LOWER" if other_perf < 0 else "HIGHER"
|
602 |
-
print(f"Your ratings are on average {np.round(abs(user_perf), 3)} {user_direction} than the existing system for this cluster")
|
603 |
-
print(f"Others' ratings (based on {len(other_ids)} users) are on average {np.round(abs(other_perf), 3)} {other_direction} than the existing system for this cluster")
|
604 |
-
|
605 |
-
user_df = preds_df[preds_df.user_id == worker_id].sort_values(by=["item_id"]).reset_index()
|
606 |
-
others_df = preds_df[preds_df.user_id == other_ids[0]]
|
607 |
-
for i in range(1, len(other_ids)):
|
608 |
-
others_df.append(preds_df[preds_df.user_id == other_ids[i]])
|
609 |
-
others_df.groupby(["item_id"]).mean()
|
610 |
-
others_df = others_df.sort_values(by=["item_id"]).reset_index()
|
611 |
-
|
612 |
-
# Get cluster_comments
|
613 |
-
user_df = user_df[user_df["comment"].isin(cluster_comments)]
|
614 |
-
others_df = others_df[others_df["comment"].isin(cluster_comments)]
|
615 |
-
|
616 |
-
df = pd.merge(user_df, others_df, on="item_id", how="left", suffixes=('_user', '_other'))
|
617 |
-
df["pred_system"] = df["rating_avg_user"]
|
618 |
-
df["pred_system_stddev"] = df["rating_stddev_user"]
|
619 |
-
df = df[["item_id", "comment_user", "pred_user", "pred_other", "pred_system", "pred_system_stddev"]]
|
620 |
-
|
621 |
-
# Add styling
|
622 |
-
df = df.sort_values(by=['pred_system_stddev'], ascending=sort_ascending)
|
623 |
-
df = df.style.apply(style_color_difference, axis=1).render()
|
624 |
-
return df
|
625 |
-
|
626 |
########################################
|
627 |
# GET_LABELING utils
|
628 |
-
def create_example_sets(
|
629 |
# Restrict to the keyword, if provided
|
630 |
-
df =
|
631 |
if keyword != None:
|
632 |
df = df[df["comment"].str.contains(keyword)]
|
633 |
|
@@ -652,8 +233,8 @@ def create_example_sets(comments_df, n_label_per_bin, score_bins, keyword=None,
|
|
652 |
|
653 |
return ex_to_label
|
654 |
|
655 |
-
def get_grp_model_labels(
|
656 |
-
df =
|
657 |
|
658 |
train_df_grp = train_df[train_df["user_id"].isin(grp_ids)]
|
659 |
train_df_grp_avg = train_df_grp.groupby(by=["item_id"]).median().reset_index()
|
@@ -689,14 +270,7 @@ def fetch_existing_data(model_name, last_label_i):
|
|
689 |
with open(os.path.join(module_dir, perf_dir, f"{last_i}.pkl"), "rb") as f:
|
690 |
mae, mse, rmse, avg_diff = pickle.load(f)
|
691 |
else:
|
692 |
-
|
693 |
-
with open(os.path.join(module_dir, f"./data/trained_models/{model_name}.pkl"), "rb") as f:
|
694 |
-
cur_model = pickle.load(f)
|
695 |
-
mae, mse, rmse, avg_diff = users_perf(cur_model)
|
696 |
-
# Cache results
|
697 |
-
os.mkdir(os.path.join(module_dir, perf_dir))
|
698 |
-
with open(os.path.join(module_dir, perf_dir, "1.pkl"), "wb") as f:
|
699 |
-
pickle.dump((mae, mse, rmse, avg_diff), f)
|
700 |
|
701 |
# Fetch previous user-provided labels
|
702 |
ratings_prev = None
|
@@ -705,7 +279,16 @@ def fetch_existing_data(model_name, last_label_i):
|
|
705 |
ratings_prev = pickle.load(f)
|
706 |
return mae, mse, rmse, avg_diff, ratings_prev
|
707 |
|
708 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
709 |
# Check if there is previously-labeled data; if so, combine it with this data
|
710 |
perf_dir = f"./data/perf/{model_name}"
|
711 |
label_dir = f"./data/labels/{model_name}"
|
@@ -716,9 +299,8 @@ def train_updated_model(model_name, last_label_i, ratings, user, top_n=20, topic
|
|
716 |
labeled_df = labeled_df[labeled_df["rating"] != -1]
|
717 |
|
718 |
# Filter to top N for user study
|
719 |
-
if topic is None:
|
720 |
-
|
721 |
-
labeled_df = labeled_df.tail(top_n)
|
722 |
else:
|
723 |
# For topic tuning, need to fetch old labels
|
724 |
if (last_label_i > 0):
|
@@ -729,29 +311,29 @@ def train_updated_model(model_name, last_label_i, ratings, user, top_n=20, topic
|
|
729 |
labeled_df_prev = labeled_df_prev[labeled_df_prev["rating"] != -1]
|
730 |
ratings.update(ratings_prev) # append old ratings to ratings
|
731 |
labeled_df = pd.concat([labeled_df_prev, labeled_df])
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
cur_model, perf, _, _ = train_user_model(ratings_df=labeled_df)
|
736 |
-
|
737 |
-
user_perf_metrics[model_name] = users_perf(cur_model)
|
738 |
-
|
739 |
-
mae, mse, rmse, avg_diff = user_perf_metrics[model_name]
|
740 |
-
|
741 |
-
cur_preds_df = get_preds_df(cur_model, ["A"], sys_eval_df=ratings_df_full) # Just get results for user
|
742 |
-
|
743 |
# Save this batch of labels
|
744 |
with open(os.path.join(module_dir, label_dir, f"{last_label_i + 1}.pkl"), "wb") as f:
|
745 |
pickle.dump(ratings, f)
|
746 |
|
747 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
748 |
with open(os.path.join(module_dir, f"./data/preds_dfs/{model_name}.pkl"), "wb") as f:
|
749 |
pickle.dump(cur_preds_df, f)
|
750 |
-
|
751 |
-
if model_name not in all_model_names:
|
752 |
-
all_model_names.append(model_name)
|
753 |
-
with open(os.path.join(module_dir, "./data/all_model_names.pkl"), "wb") as f:
|
754 |
-
pickle.dump(all_model_names, f)
|
755 |
|
756 |
# Handle user
|
757 |
if user not in users_to_models:
|
@@ -761,22 +343,10 @@ def train_updated_model(model_name, last_label_i, ratings, user, top_n=20, topic
|
|
761 |
with open(f"./data/users_to_models.pkl", "wb") as f:
|
762 |
pickle.dump(users_to_models, f)
|
763 |
|
764 |
-
with open(os.path.join(module_dir, "./data/user_perf_metrics.pkl"), "wb") as f:
|
765 |
-
pickle.dump(user_perf_metrics, f)
|
766 |
-
with open(os.path.join(module_dir, f"./data/trained_models/{model_name}.pkl"), "wb") as f:
|
767 |
-
pickle.dump(cur_model, f)
|
768 |
-
|
769 |
-
# Cache performance results
|
770 |
-
if not os.path.isdir(os.path.join(module_dir, perf_dir)):
|
771 |
-
os.mkdir(os.path.join(module_dir, perf_dir))
|
772 |
-
last_perf_i = len([name for name in os.listdir(os.path.join(module_dir, perf_dir)) if os.path.isfile(os.path.join(module_dir, perf_dir, name))])
|
773 |
-
with open(os.path.join(module_dir, perf_dir, f"{last_perf_i + 1}.pkl"), "wb") as f:
|
774 |
-
pickle.dump((mae, mse, rmse, avg_diff), f)
|
775 |
-
|
776 |
ratings_prev = ratings
|
777 |
return mae, mse, rmse, avg_diff, ratings_prev
|
778 |
|
779 |
-
def format_labeled_data(ratings, worker_id="A"
|
780 |
all_rows = []
|
781 |
for comment, rating in ratings.items():
|
782 |
comment_id = comments_to_ids[comment]
|
@@ -786,7 +356,7 @@ def format_labeled_data(ratings, worker_id="A", debug=False):
|
|
786 |
df = pd.DataFrame(all_rows, columns=["user_id", "item_id", "rating"])
|
787 |
return df
|
788 |
|
789 |
-
def users_perf(model, sys_eval_df=sys_eval_df,
|
790 |
# Load the full empty dataset
|
791 |
sys_eval_comment_ids = sys_eval_df.item_id.unique().tolist()
|
792 |
empty_ratings_rows = [[worker_id, c_id, 0] for c_id in sys_eval_comment_ids]
|
@@ -802,17 +372,17 @@ def users_perf(model, sys_eval_df=sys_eval_df, avg_ratings_df=comments_grouped_f
|
|
802 |
user_item_preds = get_predictions_by_user_and_item(predictions)
|
803 |
df["pred"] = df.apply(lambda row: user_item_preds[(row.user_id, row.item_id)] if (row.user_id, row.item_id) in user_item_preds else np.nan, axis=1)
|
804 |
|
805 |
-
df = df.merge(
|
806 |
df.dropna(subset = ["pred"], inplace=True)
|
807 |
-
df["
|
808 |
|
809 |
perf_metrics = get_overall_perf(df, "A") # mae, mse, rmse, avg_diff
|
810 |
return perf_metrics
|
811 |
|
812 |
def get_overall_perf(preds_df, user_id):
|
813 |
# Prepare dataset to calculate performance
|
814 |
-
y_pred = preds_df[preds_df["user_id"] == user_id].
|
815 |
-
y_true = preds_df[preds_df["user_id"] == user_id].pred.to_numpy()
|
816 |
|
817 |
# Get performance for user's model
|
818 |
mae = mean_absolute_error(y_true, y_pred)
|
@@ -831,9 +401,8 @@ def get_predictions_by_user_and_item(predictions):
|
|
831 |
# Pre-computes predictions for the provided model and specified users on the system-eval dataset
|
832 |
# - model: trained model
|
833 |
# - user_ids: list of user IDs to compute predictions for
|
834 |
-
# - avg_ratings_df: dataframe of average ratings for each comment (pre-computed)
|
835 |
# - sys_eval_df: dataframe of system eval labels (pre-computed)
|
836 |
-
def get_preds_df(model, user_ids,
|
837 |
# Prep dataframe for all predictions we'd like to request
|
838 |
start = time.time()
|
839 |
sys_eval_comment_ids = sys_eval_df.item_id.unique().tolist()
|
@@ -857,9 +426,9 @@ def get_preds_df(model, user_ids, avg_ratings_df=comments_grouped_full_topic_cat
|
|
857 |
df = empty_ratings_df.copy() # user_id, item_id, rating
|
858 |
user_item_preds = get_predictions_by_user_and_item(predictions)
|
859 |
df["pred"] = df.apply(lambda row: user_item_preds[(row.user_id, row.item_id)] if (row.user_id, row.item_id) in user_item_preds else np.nan, axis=1)
|
860 |
-
df = df.merge(
|
861 |
df.dropna(subset = ["pred"], inplace=True)
|
862 |
-
df["
|
863 |
|
864 |
# Get binned predictions (based on user prediction)
|
865 |
df["prediction_bin"], out_bins = pd.cut(df["pred"], bins, labels=False, retbins=True)
|
@@ -925,46 +494,6 @@ def train_model(train_df, model_eval_df, model_type="SVD", sim_type=None, user_b
|
|
925 |
|
926 |
return algo, perf
|
927 |
|
928 |
-
def plot_train_perf_results2(model_name):
|
929 |
-
# Open labels
|
930 |
-
label_dir = f"./data/labels/{model_name}"
|
931 |
-
n_label_files = len([name for name in os.listdir(os.path.join(module_dir, label_dir)) if os.path.isfile(os.path.join(module_dir, label_dir, name))])
|
932 |
-
|
933 |
-
all_rows = []
|
934 |
-
with open(os.path.join(module_dir, label_dir, f"{n_label_files}.pkl"), "rb") as f:
|
935 |
-
ratings = pickle.load(f)
|
936 |
-
|
937 |
-
labeled_df = format_labeled_data(ratings)
|
938 |
-
labeled_df = labeled_df[labeled_df["rating"] != -1]
|
939 |
-
|
940 |
-
# Iterate through batches of 5 labels
|
941 |
-
n_batches = int(np.ceil(len(labeled_df) / 5.))
|
942 |
-
for i in range(n_batches):
|
943 |
-
start = time.time()
|
944 |
-
n_to_sample = np.min([5 * (i + 1), len(labeled_df)])
|
945 |
-
cur_model, _, _, _ = train_user_model(ratings_df=labeled_df.head(n_to_sample))
|
946 |
-
mae, mse, rmse, avg_diff = users_perf(cur_model)
|
947 |
-
all_rows.append([n_to_sample, mae, "MAE"])
|
948 |
-
print(f"iter {i}: {time.time() - start}")
|
949 |
-
|
950 |
-
print("all_rows", all_rows)
|
951 |
-
|
952 |
-
df = pd.DataFrame(all_rows, columns=["n_to_sample", "perf", "metric"])
|
953 |
-
chart = alt.Chart(df).mark_line(point=True).encode(
|
954 |
-
x=alt.X("n_to_sample:Q", title="Number of Comments Labeled"),
|
955 |
-
y="perf",
|
956 |
-
color="metric",
|
957 |
-
tooltip=[
|
958 |
-
alt.Tooltip('n_to_sample:Q', title="Number of Comments Labeled"),
|
959 |
-
alt.Tooltip('metric:N', title="Metric"),
|
960 |
-
alt.Tooltip('perf:Q', title="Metric Value", format=".3f"),
|
961 |
-
],
|
962 |
-
).properties(
|
963 |
-
title=f"Performance over number of examples: {model_name}",
|
964 |
-
width=500,
|
965 |
-
)
|
966 |
-
return chart
|
967 |
-
|
968 |
def plot_train_perf_results(model_name, mae):
|
969 |
perf_dir = f"./data/perf/{model_name}"
|
970 |
n_perf_files = len([name for name in os.listdir(os.path.join(module_dir, perf_dir)) if os.path.isfile(os.path.join(module_dir, perf_dir, name))])
|
@@ -996,7 +525,7 @@ def plot_train_perf_results(model_name, mae):
|
|
996 |
|
997 |
plot_dim_width = 500
|
998 |
domain_min = 0.0
|
999 |
-
domain_max =
|
1000 |
bkgd = alt.Chart(pd.DataFrame({
|
1001 |
"start": [PCT_90, PCT_75, domain_min],
|
1002 |
"stop": [domain_max, PCT_90, PCT_75],
|
@@ -1119,14 +648,14 @@ def get_decision(rating, threshold):
|
|
1119 |
|
1120 |
def get_category(row, threshold=0.3):
|
1121 |
k_to_category = {
|
1122 |
-
"
|
1123 |
-
"
|
1124 |
-
"
|
1125 |
-
"
|
1126 |
-
"
|
1127 |
}
|
1128 |
categories = []
|
1129 |
-
for k in ["
|
1130 |
if row[k] > threshold:
|
1131 |
categories.append(k_to_category[k])
|
1132 |
|
@@ -1139,20 +668,20 @@ def get_comment_url(row):
|
|
1139 |
return f"#{row['item_id']}/#comment"
|
1140 |
|
1141 |
def get_topic_url(row):
|
1142 |
-
return f"#{row['
|
1143 |
|
1144 |
# Plots overall results histogram (each block is a topic)
|
1145 |
-
def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD,
|
1146 |
df = preds_df.copy().reset_index()
|
1147 |
|
1148 |
if n_topics is not None:
|
1149 |
-
df = df[df["
|
1150 |
|
1151 |
df["vis_pred_bin"], out_bins = pd.cut(df["pred"], bins, labels=VIS_BINS_LABELS, retbins=True)
|
1152 |
df = df[df["user_id"] == "A"].sort_values(by=["item_id"]).reset_index()
|
1153 |
-
df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[
|
1154 |
-
df["threshold"] = [threshold for r in df[
|
1155 |
-
df["key"] = [get_key(sys, user, threshold) for sys, user in zip(df[
|
1156 |
df["url"] = df.apply(lambda row: get_topic_url(row), axis=1)
|
1157 |
|
1158 |
# Plot sizing
|
@@ -1170,7 +699,7 @@ def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, b
|
|
1170 |
# Main chart
|
1171 |
chart = alt.Chart(df).mark_square(opacity=0.8, size=mark_size, stroke="grey", strokeWidth=0.5).transform_window(
|
1172 |
groupby=['vis_pred_bin'],
|
1173 |
-
sort=[{'field':
|
1174 |
id='row_number()',
|
1175 |
ignorePeers=True,
|
1176 |
).encode(
|
@@ -1183,9 +712,9 @@ def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, b
|
|
1183 |
),
|
1184 |
href="url:N",
|
1185 |
tooltip = [
|
1186 |
-
alt.Tooltip("
|
1187 |
alt.Tooltip("system_label:N", title="System label"),
|
1188 |
-
alt.Tooltip("
|
1189 |
alt.Tooltip("pred:Q", title="Your rating", format=".2f")
|
1190 |
]
|
1191 |
)
|
@@ -1260,13 +789,13 @@ def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, b
|
|
1260 |
|
1261 |
# Plots cluster results histogram (each block is a comment), but *without* a model
|
1262 |
# as a point of reference (in contrast to plot_overall_vis_cluster)
|
1263 |
-
def plot_overall_vis_cluster_no_model(preds_df, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD,
|
1264 |
df = preds_df.copy().reset_index()
|
1265 |
|
1266 |
-
df["vis_pred_bin"], out_bins = pd.cut(df[
|
1267 |
-
df = df[df["user_id"] == "A"].sort_values(by=[
|
1268 |
-
df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[
|
1269 |
-
df["key"] = [get_key_no_model(sys, threshold) for sys in df[
|
1270 |
df["category"] = df.apply(lambda row: get_category(row), axis=1)
|
1271 |
df["url"] = df.apply(lambda row: get_comment_url(row), axis=1)
|
1272 |
|
@@ -1288,7 +817,7 @@ def plot_overall_vis_cluster_no_model(preds_df, n_comments=None, bins=VIS_BINS,
|
|
1288 |
# Main chart
|
1289 |
chart = alt.Chart(df).mark_square(opacity=0.8, size=mark_size, stroke="grey", strokeWidth=0.25).transform_window(
|
1290 |
groupby=['vis_pred_bin'],
|
1291 |
-
sort=[{'field':
|
1292 |
id='row_number()',
|
1293 |
ignorePeers=True
|
1294 |
).encode(
|
@@ -1302,8 +831,8 @@ def plot_overall_vis_cluster_no_model(preds_df, n_comments=None, bins=VIS_BINS,
|
|
1302 |
),
|
1303 |
href="url:N",
|
1304 |
tooltip = [
|
1305 |
-
alt.Tooltip("
|
1306 |
-
alt.Tooltip("
|
1307 |
]
|
1308 |
)
|
1309 |
|
@@ -1356,20 +885,20 @@ def plot_overall_vis_cluster_no_model(preds_df, n_comments=None, bins=VIS_BINS,
|
|
1356 |
return final_plot, df
|
1357 |
|
1358 |
# Plots cluster results histogram (each block is a comment) *with* a model as a point of reference
|
1359 |
-
def plot_overall_vis_cluster(preds_df, error_type, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD,
|
1360 |
-
df = preds_df.copy().reset_index(
|
1361 |
|
1362 |
df["vis_pred_bin"], out_bins = pd.cut(df["pred"], bins, labels=VIS_BINS_LABELS, retbins=True)
|
1363 |
-
df = df[df["user_id"] == "A"].sort_values(by=[
|
1364 |
-
df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[
|
1365 |
-
df["key"] = [get_key(sys, user, threshold) for sys, user in zip(df[
|
1366 |
df["category"] = df.apply(lambda row: get_category(row), axis=1)
|
1367 |
df["url"] = df.apply(lambda row: get_comment_url(row), axis=1)
|
1368 |
|
1369 |
if n_comments is not None:
|
1370 |
n_to_sample = np.min([n_comments, len(df)])
|
1371 |
df = df.sample(n=n_to_sample)
|
1372 |
-
|
1373 |
# Plot sizing
|
1374 |
domain_min = 0
|
1375 |
domain_max = 4
|
@@ -1384,7 +913,7 @@ def plot_overall_vis_cluster(preds_df, error_type, n_comments=None, bins=VIS_BIN
|
|
1384 |
# Main chart
|
1385 |
chart = alt.Chart(df).mark_square(opacity=0.8, size=mark_size, stroke="grey", strokeWidth=0.25).transform_window(
|
1386 |
groupby=['vis_pred_bin'],
|
1387 |
-
sort=[{'field':
|
1388 |
id='row_number()',
|
1389 |
ignorePeers=True
|
1390 |
).encode(
|
@@ -1397,8 +926,8 @@ def plot_overall_vis_cluster(preds_df, error_type, n_comments=None, bins=VIS_BIN
|
|
1397 |
),
|
1398 |
href="url:N",
|
1399 |
tooltip = [
|
1400 |
-
alt.Tooltip("
|
1401 |
-
alt.Tooltip("
|
1402 |
alt.Tooltip("pred:Q", title="Your rating", format=".2f"),
|
1403 |
alt.Tooltip("category:N", title="Potential toxicity categories")
|
1404 |
]
|
@@ -1464,30 +993,26 @@ def plot_overall_vis_cluster(preds_df, error_type, n_comments=None, bins=VIS_BIN
|
|
1464 |
|
1465 |
return final_plot, df
|
1466 |
|
1467 |
-
def get_cluster_comments(df, error_type, threshold=TOXIC_THRESHOLD,
|
1468 |
df["user_color"] = [get_user_color(user, threshold) for user in df["pred"].tolist()] # get cell colors
|
1469 |
-
df["system_color"] = [get_user_color(sys, threshold) for sys in df[
|
1470 |
-
df["error_color"] = [get_system_color(sys, user, threshold) for sys, user in zip(df[
|
1471 |
-
df["error_type"] = [get_error_type(sys, user, threshold) for sys, user in zip(df[
|
1472 |
-
df["error_amt"] = [abs(sys - threshold) for sys in df[
|
1473 |
df["judgment"] = ["" for _ in range(len(df))] # template for "agree" or "disagree" buttons
|
1474 |
|
1475 |
if use_model:
|
1476 |
df = df.sort_values(by=["error_amt"], ascending=False) # surface largest errors first
|
1477 |
else:
|
1478 |
print("get_cluster_comments; not using model")
|
1479 |
-
df = df.sort_values(by=[
|
1480 |
|
1481 |
df["id"] = df["item_id"]
|
1482 |
-
# df["comment"] already exists
|
1483 |
-
df["comment"] = df["comment_"]
|
1484 |
df["toxicity_category"] = df["category"]
|
1485 |
df["user_rating"] = df["pred"]
|
1486 |
df["user_decision"] = [get_decision(rating, threshold) for rating in df["pred"].tolist()]
|
1487 |
-
df["system_rating"] = df[
|
1488 |
-
df["system_decision"] = [get_decision(rating, threshold) for rating in df[
|
1489 |
-
df["error_type"] = df["error_type"]
|
1490 |
-
df = df.head(num_examples)
|
1491 |
df = df.round(decimals=2)
|
1492 |
|
1493 |
# Filter to specified error type
|
@@ -1500,7 +1025,7 @@ def get_cluster_comments(df, error_type, threshold=TOXIC_THRESHOLD, worker_id="A
|
|
1500 |
elif error_type == "Both":
|
1501 |
df = df[(df["error_type"] == "System may be under-sensitive") | (df["error_type"] == "System may be over-sensitive")]
|
1502 |
|
1503 |
-
return df
|
1504 |
|
1505 |
# PERSONALIZED CLUSTERS utils
|
1506 |
def get_disagreement_comments(preds_df, mode, n=10_000, threshold=TOXIC_THRESHOLD):
|
@@ -1519,58 +1044,10 @@ def get_disagreement_comments(preds_df, mode, n=10_000, threshold=TOXIC_THRESHOL
|
|
1519 |
df = df.sort_values(by=["diff"], ascending=asc)
|
1520 |
df = df.head(n)
|
1521 |
|
1522 |
-
return df["
|
1523 |
-
|
1524 |
-
def
|
1525 |
-
|
1526 |
-
|
1527 |
-
|
1528 |
-
|
1529 |
-
cluster_df = cluster_df.sort_values(by=["topic_id"])
|
1530 |
-
topics_under = cluster_df[cluster_df["error_type"] == "System may be under-sensitive"]["topic"].unique().tolist()
|
1531 |
-
topics_under = topics_under[1:(n + 1)]
|
1532 |
-
topics_over = cluster_df[cluster_df["error_type"] == "System may be over-sensitive"]["topic"].unique().tolist()
|
1533 |
-
topics_over = topics_over[1:(n + 1)]
|
1534 |
-
return topics_under, topics_over
|
1535 |
-
else:
|
1536 |
-
topics_under_top = []
|
1537 |
-
topics_over_top = []
|
1538 |
-
preds_df_file = f"./data/preds_dfs/{model}.pkl"
|
1539 |
-
if (os.path.isfile(preds_df_file)):
|
1540 |
-
with open(preds_df_file, "rb") as f:
|
1541 |
-
preds_df = pickle.load(f)
|
1542 |
-
preds_df_mod = preds_df.merge(comments_grouped_full_topic_cat, on="item_id", how="left", suffixes=('_', '_avg')).reset_index()
|
1543 |
-
preds_df_mod = preds_df_mod[preds_df_mod["user_id"] == "A"]
|
1544 |
-
|
1545 |
-
comments_under, comments_under_df = get_disagreement_comments(preds_df_mod, mode="under-sensitive", n=1000)
|
1546 |
-
if len(comments_under) > 0:
|
1547 |
-
topics_under = BERTopic(embedding_model="paraphrase-MiniLM-L6-v2").fit(comments_under)
|
1548 |
-
topics_under_top = topics_under.get_topic_info().head(n)["Name"].tolist()
|
1549 |
-
print("topics_under", topics_under_top)
|
1550 |
-
# Get topics per comment
|
1551 |
-
topics_assigned, _ = topics_under.transform(comments_under)
|
1552 |
-
comments_under_df["topic_id"] = topics_assigned
|
1553 |
-
cur_topic_ids = topics_under.get_topic_info().Topic
|
1554 |
-
topic_short_names = topics_under.get_topic_info().Name
|
1555 |
-
topic_ids_to_names = {cur_topic_ids[i]: topic_short_names[i] for i in range(len(cur_topic_ids))}
|
1556 |
-
comments_under_df["topic"] = [topic_ids_to_names[topic_id] for topic_id in comments_under_df["topic_id"].tolist()]
|
1557 |
-
|
1558 |
-
comments_over, comments_over_df = get_disagreement_comments(preds_df_mod, mode="over-sensitive", n=1000)
|
1559 |
-
if len(comments_over) > 0:
|
1560 |
-
topics_over = BERTopic(embedding_model="paraphrase-MiniLM-L6-v2").fit(comments_over)
|
1561 |
-
topics_over_top = topics_over.get_topic_info().head(n)["Name"].tolist()
|
1562 |
-
print("topics_over", topics_over_top)
|
1563 |
-
# Get topics per comment
|
1564 |
-
topics_assigned, _ = topics_over.transform(comments_over)
|
1565 |
-
comments_over_df["topic_id"] = topics_assigned
|
1566 |
-
cur_topic_ids = topics_over.get_topic_info().Topic
|
1567 |
-
topic_short_names = topics_over.get_topic_info().Name
|
1568 |
-
topic_ids_to_names = {cur_topic_ids[i]: topic_short_names[i] for i in range(len(cur_topic_ids))}
|
1569 |
-
comments_over_df["topic"] = [topic_ids_to_names[topic_id] for topic_id in comments_over_df["topic_id"].tolist()]
|
1570 |
-
|
1571 |
-
cluster_df = pd.concat([comments_under_df, comments_over_df])
|
1572 |
-
with open(f"./data/personal_cluster_dfs/{model}.pkl", "wb") as f:
|
1573 |
-
pickle.dump(cluster_df, f)
|
1574 |
-
|
1575 |
-
return topics_under_top, topics_over_top
|
1576 |
-
return [], []
|
|
|
40 |
perf_dir = f"data/perf/"
|
41 |
|
42 |
# # TEMP reset
|
|
|
|
|
|
|
43 |
# with open(f"./data/users_to_models.pkl", "wb") as f:
|
44 |
# users_to_models = {}
|
45 |
# pickle.dump(users_to_models, f)
|
46 |
|
47 |
+
with open(os.path.join(module_dir, "data/input/ids_to_comments.pkl"), "rb") as f:
|
|
|
48 |
ids_to_comments = pickle.load(f)
|
49 |
+
with open(os.path.join(module_dir, "data/input/comments_to_ids.pkl"), "rb") as f:
|
50 |
comments_to_ids = pickle.load(f)
|
51 |
+
system_preds_df = pd.read_pickle("data/input/system_preds_df.pkl")
|
52 |
+
sys_eval_df = pd.read_pickle(os.path.join(module_dir, "data/input/split_data/sys_eval_df.pkl"))
|
53 |
+
train_df = pd.read_pickle(os.path.join(module_dir, "data/input/split_data/train_df.pkl"))
|
|
|
|
|
54 |
train_df_ids = train_df["item_id"].unique().tolist()
|
55 |
+
model_eval_df = pd.read_pickle(os.path.join(module_dir, "data/input/split_data/model_eval_df.pkl"))
|
56 |
+
ratings_df_full = pd.read_pickle(os.path.join(module_dir, "data/input/ratings_df_full.pkl"))
|
57 |
+
worker_info_df = pd.read_pickle("./data/input/worker_info_df.pkl")
|
|
|
58 |
|
59 |
with open(f"./data/users_to_models.pkl", "rb") as f:
|
60 |
users_to_models = pickle.load(f)
|
61 |
|
62 |
+
topic_ids = system_preds_df.topic_id
|
63 |
+
topics = system_preds_df.topic
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
topic_ids_to_topics = {topic_ids[i]: topics[i] for i in range(len(topic_ids))}
|
65 |
topics_to_topic_ids = {topics[i]: topic_ids[i] for i in range(len(topic_ids))}
|
66 |
+
unique_topics_ids = sorted(system_preds_df.topic_id.unique())
|
67 |
unique_topics = [topic_ids_to_topics[topic_id] for topic_id in range(len(topic_ids_to_topics) - 1)]
|
68 |
|
69 |
def get_toxic_threshold():
|
70 |
return TOXIC_THRESHOLD
|
71 |
|
72 |
+
def get_user_model_names(user):
|
73 |
+
# Fetch the user's models
|
74 |
+
if user not in users_to_models:
|
75 |
+
users_to_models[user] = []
|
76 |
+
user_models = users_to_models[user]
|
77 |
+
user_models.sort()
|
78 |
+
return user_models
|
|
|
|
|
79 |
|
80 |
def get_unique_topics():
|
81 |
return unique_topics
|
82 |
|
83 |
def get_large_clusters(min_n):
|
84 |
+
counts_df = system_preds_df.groupby(by=["topic_id"]).size().reset_index(name='counts')
|
85 |
counts_df = counts_df[counts_df["counts"] >= min_n]
|
86 |
return [topic_ids_to_topics[t_id] for t_id in sorted(counts_df["topic_id"].tolist()[1:])]
|
87 |
|
|
|
119 |
}
|
120 |
internal_to_readable = {v: k for k, v in readable_to_internal.items()}
|
121 |
|
122 |
+
def get_system_preds_df():
|
123 |
+
return system_preds_df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
########################################
|
126 |
# General utils
|
|
|
150 |
|
151 |
########################################
|
152 |
# GET_AUDIT utils
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
def plot_metric_histogram(metric, user_metric, other_metric_vals, n_bins=10):
|
154 |
hist, bin_edges = np.histogram(other_metric_vals, bins=n_bins, density=False)
|
155 |
data = pd.DataFrame({
|
|
|
181 |
|
182 |
return (bar + rule).interactive()
|
183 |
|
184 |
+
# Generates the summary plot across all topics for the user
|
185 |
+
def show_overall_perf(variant, error_type, cur_user, threshold=TOXIC_THRESHOLD, topic_vis_method="median"):
|
186 |
+
# Your perf (calculate using model and testset)
|
187 |
+
with open(os.path.join(module_dir, f"data/preds_dfs/{variant}.pkl"), "rb") as f:
|
188 |
+
preds_df = pickle.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
+
# Read from file
|
191 |
+
chart_dir = "./data/charts"
|
192 |
+
chart_file = os.path.join(chart_dir, f"{cur_user}_{variant}.pkl")
|
193 |
+
if os.path.isfile(chart_file):
|
194 |
+
with open(chart_file, "r") as f:
|
195 |
+
topic_overview_plot_json = json.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
else:
|
197 |
+
if topic_vis_method == "median": # Default
|
198 |
+
preds_df_grp = preds_df.groupby(["topic", "user_id"]).median()
|
199 |
+
elif topic_vis_method == "mean":
|
200 |
+
preds_df_grp = preds_df.groupby(["topic", "user_id"]).mean()
|
201 |
+
topic_overview_plot_json = plot_overall_vis(preds_df=preds_df_grp, n_topics=200, threshold=threshold, error_type=error_type, cur_user=cur_user, cur_model=variant)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
return {
|
204 |
"topic_overview_plot_json": json.loads(topic_overview_plot_json),
|
205 |
}
|
206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
########################################
|
208 |
# GET_LABELING utils
|
209 |
+
def create_example_sets(n_label_per_bin, score_bins, keyword=None, topic=None):
|
210 |
# Restrict to the keyword, if provided
|
211 |
+
df = system_preds_df.copy()
|
212 |
if keyword != None:
|
213 |
df = df[df["comment"].str.contains(keyword)]
|
214 |
|
|
|
233 |
|
234 |
return ex_to_label
|
235 |
|
236 |
+
def get_grp_model_labels(n_label_per_bin, score_bins, grp_ids):
|
237 |
+
df = system_preds_df.copy()
|
238 |
|
239 |
train_df_grp = train_df[train_df["user_id"].isin(grp_ids)]
|
240 |
train_df_grp_avg = train_df_grp.groupby(by=["item_id"]).median().reset_index()
|
|
|
270 |
with open(os.path.join(module_dir, perf_dir, f"{last_i}.pkl"), "rb") as f:
|
271 |
mae, mse, rmse, avg_diff = pickle.load(f)
|
272 |
else:
|
273 |
+
raise Exception(f"Model {model_name} does not exist")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
|
275 |
# Fetch previous user-provided labels
|
276 |
ratings_prev = None
|
|
|
279 |
ratings_prev = pickle.load(f)
|
280 |
return mae, mse, rmse, avg_diff, ratings_prev
|
281 |
|
282 |
+
# Main function called by server's `get_personalized_model` endpoint
|
283 |
+
# Trains an updated model with the specified name, user, and ratings
|
284 |
+
# Saves ratings, performance metrics, and pre-computed predictions to files
|
285 |
+
# - model_name: name of the model to train
|
286 |
+
# - last_label_i: index of the last label file (0 if none exists)
|
287 |
+
# - ratings: dictionary of comments to ratings
|
288 |
+
# - user: user name
|
289 |
+
# - top_n: number of comments to train on (used when a set was held out for original user study)
|
290 |
+
# - topic: topic to train on (used when tuning for a specific topic)
|
291 |
+
def train_updated_model(model_name, last_label_i, ratings, user, top_n=None, topic=None, debug=False):
|
292 |
# Check if there is previously-labeled data; if so, combine it with this data
|
293 |
perf_dir = f"./data/perf/{model_name}"
|
294 |
label_dir = f"./data/labels/{model_name}"
|
|
|
299 |
labeled_df = labeled_df[labeled_df["rating"] != -1]
|
300 |
|
301 |
# Filter to top N for user study
|
302 |
+
if (topic is None) and (top_n is not None):
|
303 |
+
labeled_df = labeled_df.head(top_n)
|
|
|
304 |
else:
|
305 |
# For topic tuning, need to fetch old labels
|
306 |
if (last_label_i > 0):
|
|
|
311 |
labeled_df_prev = labeled_df_prev[labeled_df_prev["rating"] != -1]
|
312 |
ratings.update(ratings_prev) # append old ratings to ratings
|
313 |
labeled_df = pd.concat([labeled_df_prev, labeled_df])
|
314 |
+
if debug:
|
315 |
+
print("len ratings for training:", len(labeled_df))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
# Save this batch of labels
|
317 |
with open(os.path.join(module_dir, label_dir, f"{last_label_i + 1}.pkl"), "wb") as f:
|
318 |
pickle.dump(ratings, f)
|
319 |
|
320 |
+
# Train model
|
321 |
+
cur_model, _, _, _ = train_user_model(ratings_df=labeled_df)
|
322 |
+
|
323 |
+
# Compute performance metrics
|
324 |
+
mae, mse, rmse, avg_diff = users_perf(cur_model)
|
325 |
+
# Save performance metrics
|
326 |
+
if not os.path.isdir(os.path.join(module_dir, perf_dir)):
|
327 |
+
os.mkdir(os.path.join(module_dir, perf_dir))
|
328 |
+
last_perf_i = len([name for name in os.listdir(os.path.join(module_dir, perf_dir)) if os.path.isfile(os.path.join(module_dir, perf_dir, name))])
|
329 |
+
with open(os.path.join(module_dir, perf_dir, f"{last_perf_i + 1}.pkl"), "wb") as f:
|
330 |
+
pickle.dump((mae, mse, rmse, avg_diff), f)
|
331 |
+
|
332 |
+
# Pre-compute predictions for full dataset
|
333 |
+
cur_preds_df = get_preds_df(cur_model, ["A"], sys_eval_df=ratings_df_full)
|
334 |
+
# Save pre-computed predictions
|
335 |
with open(os.path.join(module_dir, f"./data/preds_dfs/{model_name}.pkl"), "wb") as f:
|
336 |
pickle.dump(cur_preds_df, f)
|
|
|
|
|
|
|
|
|
|
|
337 |
|
338 |
# Handle user
|
339 |
if user not in users_to_models:
|
|
|
343 |
with open(f"./data/users_to_models.pkl", "wb") as f:
|
344 |
pickle.dump(users_to_models, f)
|
345 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
ratings_prev = ratings
|
347 |
return mae, mse, rmse, avg_diff, ratings_prev
|
348 |
|
349 |
+
def format_labeled_data(ratings, worker_id="A"):
|
350 |
all_rows = []
|
351 |
for comment, rating in ratings.items():
|
352 |
comment_id = comments_to_ids[comment]
|
|
|
356 |
df = pd.DataFrame(all_rows, columns=["user_id", "item_id", "rating"])
|
357 |
return df
|
358 |
|
359 |
+
def users_perf(model, sys_eval_df=sys_eval_df, worker_id="A"):
|
360 |
# Load the full empty dataset
|
361 |
sys_eval_comment_ids = sys_eval_df.item_id.unique().tolist()
|
362 |
empty_ratings_rows = [[worker_id, c_id, 0] for c_id in sys_eval_comment_ids]
|
|
|
372 |
user_item_preds = get_predictions_by_user_and_item(predictions)
|
373 |
df["pred"] = df.apply(lambda row: user_item_preds[(row.user_id, row.item_id)] if (row.user_id, row.item_id) in user_item_preds else np.nan, axis=1)
|
374 |
|
375 |
+
df = df.merge(system_preds_df, on="item_id", how="left", suffixes=('', '_sys'))
|
376 |
df.dropna(subset = ["pred"], inplace=True)
|
377 |
+
df["rating"] = df.rating.astype("int32")
|
378 |
|
379 |
perf_metrics = get_overall_perf(df, "A") # mae, mse, rmse, avg_diff
|
380 |
return perf_metrics
|
381 |
|
382 |
def get_overall_perf(preds_df, user_id):
|
383 |
# Prepare dataset to calculate performance
|
384 |
+
y_pred = preds_df[preds_df["user_id"] == user_id].rating_sys.to_numpy() # system's prediction
|
385 |
+
y_true = preds_df[preds_df["user_id"] == user_id].pred.to_numpy() # user's (predicted) ground truth
|
386 |
|
387 |
# Get performance for user's model
|
388 |
mae = mean_absolute_error(y_true, y_pred)
|
|
|
401 |
# Pre-computes predictions for the provided model and specified users on the system-eval dataset
|
402 |
# - model: trained model
|
403 |
# - user_ids: list of user IDs to compute predictions for
|
|
|
404 |
# - sys_eval_df: dataframe of system eval labels (pre-computed)
|
405 |
+
def get_preds_df(model, user_ids, sys_eval_df=sys_eval_df, bins=BINS):
|
406 |
# Prep dataframe for all predictions we'd like to request
|
407 |
start = time.time()
|
408 |
sys_eval_comment_ids = sys_eval_df.item_id.unique().tolist()
|
|
|
426 |
df = empty_ratings_df.copy() # user_id, item_id, rating
|
427 |
user_item_preds = get_predictions_by_user_and_item(predictions)
|
428 |
df["pred"] = df.apply(lambda row: user_item_preds[(row.user_id, row.item_id)] if (row.user_id, row.item_id) in user_item_preds else np.nan, axis=1)
|
429 |
+
df = df.merge(system_preds_df, on="item_id", how="left", suffixes=('', '_sys'))
|
430 |
df.dropna(subset = ["pred"], inplace=True)
|
431 |
+
df["rating"] = df.rating.astype("int32")
|
432 |
|
433 |
# Get binned predictions (based on user prediction)
|
434 |
df["prediction_bin"], out_bins = pd.cut(df["pred"], bins, labels=False, retbins=True)
|
|
|
494 |
|
495 |
return algo, perf
|
496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
497 |
def plot_train_perf_results(model_name, mae):
|
498 |
perf_dir = f"./data/perf/{model_name}"
|
499 |
n_perf_files = len([name for name in os.listdir(os.path.join(module_dir, perf_dir)) if os.path.isfile(os.path.join(module_dir, perf_dir, name))])
|
|
|
525 |
|
526 |
plot_dim_width = 500
|
527 |
domain_min = 0.0
|
528 |
+
domain_max = 2.0
|
529 |
bkgd = alt.Chart(pd.DataFrame({
|
530 |
"start": [PCT_90, PCT_75, domain_min],
|
531 |
"stop": [domain_max, PCT_90, PCT_75],
|
|
|
648 |
|
649 |
def get_category(row, threshold=0.3):
|
650 |
k_to_category = {
|
651 |
+
"is_profane_frac": "Profanity",
|
652 |
+
"is_threat_frac": "Threat",
|
653 |
+
"is_identity_attack_frac": "Identity Attack",
|
654 |
+
"is_insult_frac": "Insult",
|
655 |
+
"is_sexual_harassment_frac": "Sexual Harassment",
|
656 |
}
|
657 |
categories = []
|
658 |
+
for k in ["is_profane_frac", "is_threat_frac", "is_identity_attack_frac", "is_insult_frac", "is_sexual_harassment_frac"]:
|
659 |
if row[k] > threshold:
|
660 |
categories.append(k_to_category[k])
|
661 |
|
|
|
668 |
return f"#{row['item_id']}/#comment"
|
669 |
|
670 |
def get_topic_url(row):
|
671 |
+
return f"#{row['topic']}/#topic"
|
672 |
|
673 |
# Plots overall results histogram (each block is a topic)
|
674 |
+
def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
|
675 |
df = preds_df.copy().reset_index()
|
676 |
|
677 |
if n_topics is not None:
|
678 |
+
df = df[df["topic_id"] < n_topics]
|
679 |
|
680 |
df["vis_pred_bin"], out_bins = pd.cut(df["pred"], bins, labels=VIS_BINS_LABELS, retbins=True)
|
681 |
df = df[df["user_id"] == "A"].sort_values(by=["item_id"]).reset_index()
|
682 |
+
df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[sys_col].tolist()]
|
683 |
+
df["threshold"] = [threshold for r in df[sys_col].tolist()]
|
684 |
+
df["key"] = [get_key(sys, user, threshold) for sys, user in zip(df[sys_col].tolist(), df["pred"].tolist())]
|
685 |
df["url"] = df.apply(lambda row: get_topic_url(row), axis=1)
|
686 |
|
687 |
# Plot sizing
|
|
|
699 |
# Main chart
|
700 |
chart = alt.Chart(df).mark_square(opacity=0.8, size=mark_size, stroke="grey", strokeWidth=0.5).transform_window(
|
701 |
groupby=['vis_pred_bin'],
|
702 |
+
sort=[{'field': sys_col}],
|
703 |
id='row_number()',
|
704 |
ignorePeers=True,
|
705 |
).encode(
|
|
|
712 |
),
|
713 |
href="url:N",
|
714 |
tooltip = [
|
715 |
+
alt.Tooltip("topic:N", title="Topic"),
|
716 |
alt.Tooltip("system_label:N", title="System label"),
|
717 |
+
alt.Tooltip(f"{sys_col}:Q", title="System rating", format=".2f"),
|
718 |
alt.Tooltip("pred:Q", title="Your rating", format=".2f")
|
719 |
]
|
720 |
)
|
|
|
789 |
|
790 |
# Plots cluster results histogram (each block is a comment), but *without* a model
|
791 |
# as a point of reference (in contrast to plot_overall_vis_cluster)
|
792 |
+
def plot_overall_vis_cluster_no_model(preds_df, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
|
793 |
df = preds_df.copy().reset_index()
|
794 |
|
795 |
+
df["vis_pred_bin"], out_bins = pd.cut(df[sys_col], bins, labels=VIS_BINS_LABELS, retbins=True)
|
796 |
+
df = df[df["user_id"] == "A"].sort_values(by=[sys_col]).reset_index()
|
797 |
+
df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[sys_col].tolist()]
|
798 |
+
df["key"] = [get_key_no_model(sys, threshold) for sys in df[sys_col].tolist()]
|
799 |
df["category"] = df.apply(lambda row: get_category(row), axis=1)
|
800 |
df["url"] = df.apply(lambda row: get_comment_url(row), axis=1)
|
801 |
|
|
|
817 |
# Main chart
|
818 |
chart = alt.Chart(df).mark_square(opacity=0.8, size=mark_size, stroke="grey", strokeWidth=0.25).transform_window(
|
819 |
groupby=['vis_pred_bin'],
|
820 |
+
sort=[{'field': sys_col}],
|
821 |
id='row_number()',
|
822 |
ignorePeers=True
|
823 |
).encode(
|
|
|
831 |
),
|
832 |
href="url:N",
|
833 |
tooltip = [
|
834 |
+
alt.Tooltip("comment:N", title="comment"),
|
835 |
+
alt.Tooltip(f"{sys_col}:Q", title="System rating", format=".2f"),
|
836 |
]
|
837 |
)
|
838 |
|
|
|
885 |
return final_plot, df
|
886 |
|
887 |
# Plots cluster results histogram (each block is a comment) *with* a model as a point of reference
|
888 |
+
def plot_overall_vis_cluster(preds_df, error_type, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
|
889 |
+
df = preds_df.copy().reset_index()
|
890 |
|
891 |
df["vis_pred_bin"], out_bins = pd.cut(df["pred"], bins, labels=VIS_BINS_LABELS, retbins=True)
|
892 |
+
df = df[df["user_id"] == "A"].sort_values(by=[sys_col]).reset_index(drop=True)
|
893 |
+
df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[sys_col].tolist()]
|
894 |
+
df["key"] = [get_key(sys, user, threshold) for sys, user in zip(df[sys_col].tolist(), df["pred"].tolist())]
|
895 |
df["category"] = df.apply(lambda row: get_category(row), axis=1)
|
896 |
df["url"] = df.apply(lambda row: get_comment_url(row), axis=1)
|
897 |
|
898 |
if n_comments is not None:
|
899 |
n_to_sample = np.min([n_comments, len(df)])
|
900 |
df = df.sample(n=n_to_sample)
|
901 |
+
|
902 |
# Plot sizing
|
903 |
domain_min = 0
|
904 |
domain_max = 4
|
|
|
913 |
# Main chart
|
914 |
chart = alt.Chart(df).mark_square(opacity=0.8, size=mark_size, stroke="grey", strokeWidth=0.25).transform_window(
|
915 |
groupby=['vis_pred_bin'],
|
916 |
+
sort=[{'field': sys_col}],
|
917 |
id='row_number()',
|
918 |
ignorePeers=True
|
919 |
).encode(
|
|
|
926 |
),
|
927 |
href="url:N",
|
928 |
tooltip = [
|
929 |
+
alt.Tooltip("comment:N", title="comment"),
|
930 |
+
alt.Tooltip(f"{sys_col}:Q", title="System rating", format=".2f"),
|
931 |
alt.Tooltip("pred:Q", title="Your rating", format=".2f"),
|
932 |
alt.Tooltip("category:N", title="Potential toxicity categories")
|
933 |
]
|
|
|
993 |
|
994 |
return final_plot, df
|
995 |
|
996 |
+
def get_cluster_comments(df, error_type, threshold=TOXIC_THRESHOLD, sys_col="rating_sys", use_model=True):
|
997 |
df["user_color"] = [get_user_color(user, threshold) for user in df["pred"].tolist()] # get cell colors
|
998 |
+
df["system_color"] = [get_user_color(sys, threshold) for sys in df[sys_col].tolist()] # get cell colors
|
999 |
+
df["error_color"] = [get_system_color(sys, user, threshold) for sys, user in zip(df[sys_col].tolist(), df["pred"].tolist())] # get cell colors
|
1000 |
+
df["error_type"] = [get_error_type(sys, user, threshold) for sys, user in zip(df[sys_col].tolist(), df["pred"].tolist())] # get error type in words
|
1001 |
+
df["error_amt"] = [abs(sys - threshold) for sys in df[sys_col].tolist()] # get raw error
|
1002 |
df["judgment"] = ["" for _ in range(len(df))] # template for "agree" or "disagree" buttons
|
1003 |
|
1004 |
if use_model:
|
1005 |
df = df.sort_values(by=["error_amt"], ascending=False) # surface largest errors first
|
1006 |
else:
|
1007 |
print("get_cluster_comments; not using model")
|
1008 |
+
df = df.sort_values(by=[sys_col], ascending=True)
|
1009 |
|
1010 |
df["id"] = df["item_id"]
|
|
|
|
|
1011 |
df["toxicity_category"] = df["category"]
|
1012 |
df["user_rating"] = df["pred"]
|
1013 |
df["user_decision"] = [get_decision(rating, threshold) for rating in df["pred"].tolist()]
|
1014 |
+
df["system_rating"] = df[sys_col]
|
1015 |
+
df["system_decision"] = [get_decision(rating, threshold) for rating in df[sys_col].tolist()]
|
|
|
|
|
1016 |
df = df.round(decimals=2)
|
1017 |
|
1018 |
# Filter to specified error type
|
|
|
1025 |
elif error_type == "Both":
|
1026 |
df = df[(df["error_type"] == "System may be under-sensitive") | (df["error_type"] == "System may be over-sensitive")]
|
1027 |
|
1028 |
+
return df
|
1029 |
|
1030 |
# PERSONALIZED CLUSTERS utils
|
1031 |
def get_disagreement_comments(preds_df, mode, n=10_000, threshold=TOXIC_THRESHOLD):
|
|
|
1044 |
df = df.sort_values(by=["diff"], ascending=asc)
|
1045 |
df = df.head(n)
|
1046 |
|
1047 |
+
return df["comment"].tolist(), df
|
1048 |
+
|
1049 |
+
def get_explore_df(n_examples, threshold):
|
1050 |
+
df = system_preds_df.sample(n=n_examples)
|
1051 |
+
df["system_decision"] = [get_decision(rating, threshold) for rating in df["rating"].tolist()]
|
1052 |
+
df["system_color"] = [get_user_color(sys, threshold) for sys in df["rating"].tolist()] # get cell colors
|
1053 |
+
return df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
indie_label_svelte/src/Auditing.svelte
CHANGED
@@ -51,8 +51,6 @@
|
|
51 |
]
|
52 |
|
53 |
let personalized_models = [];
|
54 |
-
let breakdown_category;
|
55 |
-
let breakdown_categories = [];
|
56 |
let systems = ["YouSocial comment toxicity classifier"]; // Only one system for now
|
57 |
let clusters = [];
|
58 |
let clusters_for_tuning = []
|
@@ -72,7 +70,6 @@
|
|
72 |
let audit_type;
|
73 |
if (scaffold_method == "fixed" || scaffold_method == "personal" || scaffold_method == "personal_group" || scaffold_method == "personal_test" || scaffold_method == "personal_cluster" || scaffold_method == "topic_train" || scaffold_method == "prompts") {
|
74 |
audit_type = audit_types[1];
|
75 |
-
// audit_type = audit_types[0];
|
76 |
} else {
|
77 |
// No scaffolding mode or tutorial
|
78 |
audit_type = audit_types[0];
|
@@ -112,7 +109,7 @@
|
|
112 |
if (!personalized_models.includes(personalized_model)) {
|
113 |
personalized_models.push(personalized_model);
|
114 |
}
|
115 |
-
|
116 |
handleClusterButton(); // re-render cluster results
|
117 |
});
|
118 |
|
@@ -142,8 +139,6 @@
|
|
142 |
.then((r) => r.text())
|
143 |
.then(function (r_orig) {
|
144 |
let r = JSON.parse(r_orig);
|
145 |
-
breakdown_categories = r["breakdown_categories"];
|
146 |
-
breakdown_category = breakdown_categories[0];
|
147 |
personalized_models = r["personalized_models"];
|
148 |
if (use_group_model) {
|
149 |
let personalized_model_grp = r["personalized_model_grp"];
|
@@ -173,7 +168,6 @@
|
|
173 |
async function getAudit() {
|
174 |
let req_params = {
|
175 |
pers_model: personalized_model,
|
176 |
-
breakdown_axis: breakdown_category,
|
177 |
perf_metric: "avg_diff",
|
178 |
breakdown_sort: "difference",
|
179 |
n_topics: 10,
|
@@ -199,13 +193,11 @@
|
|
199 |
let req_params = {
|
200 |
cluster: topic,
|
201 |
topic_df_ids: [],
|
202 |
-
n_examples: 500, // TEMP
|
203 |
pers_model: personalized_model,
|
204 |
example_sort: "descending", // TEMP
|
205 |
comparison_group: "status_quo", // TEMP
|
206 |
search_type: "cluster",
|
207 |
keyword: "",
|
208 |
-
n_neighbors: 0,
|
209 |
error_type: cur_error_type,
|
210 |
use_model: use_model,
|
211 |
scaffold_method: scaffold_method,
|
@@ -223,16 +215,13 @@
|
|
223 |
<div>
|
224 |
<div style="margin-top: 30px">
|
225 |
<span class="head_3">Auditing</span>
|
226 |
-
<IconButton
|
227 |
-
class="material-icons grey_button"
|
228 |
-
size="normal"
|
229 |
-
on:click={() => (show_audit_settings = !show_audit_settings)}
|
230 |
-
>
|
231 |
-
help_outline
|
232 |
-
</IconButton>
|
233 |
</div>
|
234 |
<div style="width: 80%">
|
|
|
235 |
<p>In this section, we'll be auditing the content moderation system. Here, you’ll be aided by a personalized model that will help direct your attention towards potential problem areas in the model’s performance. This model isn’t meant to be perfect, but is designed to help you better focus on areas that need human review.</p>
|
|
|
|
|
|
|
236 |
</div>
|
237 |
|
238 |
{#if show_audit_settings}
|
@@ -282,11 +271,14 @@
|
|
282 |
</LayoutGrid>
|
283 |
</div>
|
284 |
</div>
|
|
|
|
|
285 |
<p>Current model: {personalized_model}</p>
|
286 |
{/if}
|
287 |
</div>
|
288 |
|
289 |
<!-- 1: All topics overview -->
|
|
|
290 |
{#if audit_type == audit_types[0]}
|
291 |
<div class="audit_section">
|
292 |
<div class="head_5">Overview of all topics</div>
|
@@ -440,7 +432,7 @@
|
|
440 |
<div class="head_5">Finalize your current report</div>
|
441 |
<p>Finally, review the report you've generated on the side panel and provide a brief summary of the problem you see. You may also list suggestions or insights into addressing this problem if you have ideas. This report will be directly used by the model developers to address the issue you've raised</p>
|
442 |
</div>
|
443 |
-
|
444 |
</div>
|
445 |
|
446 |
<style>
|
|
|
51 |
]
|
52 |
|
53 |
let personalized_models = [];
|
|
|
|
|
54 |
let systems = ["YouSocial comment toxicity classifier"]; // Only one system for now
|
55 |
let clusters = [];
|
56 |
let clusters_for_tuning = []
|
|
|
70 |
let audit_type;
|
71 |
if (scaffold_method == "fixed" || scaffold_method == "personal" || scaffold_method == "personal_group" || scaffold_method == "personal_test" || scaffold_method == "personal_cluster" || scaffold_method == "topic_train" || scaffold_method == "prompts") {
|
72 |
audit_type = audit_types[1];
|
|
|
73 |
} else {
|
74 |
// No scaffolding mode or tutorial
|
75 |
audit_type = audit_types[0];
|
|
|
109 |
if (!personalized_models.includes(personalized_model)) {
|
110 |
personalized_models.push(personalized_model);
|
111 |
}
|
112 |
+
handleAuditButton();
|
113 |
handleClusterButton(); // re-render cluster results
|
114 |
});
|
115 |
|
|
|
139 |
.then((r) => r.text())
|
140 |
.then(function (r_orig) {
|
141 |
let r = JSON.parse(r_orig);
|
|
|
|
|
142 |
personalized_models = r["personalized_models"];
|
143 |
if (use_group_model) {
|
144 |
let personalized_model_grp = r["personalized_model_grp"];
|
|
|
168 |
async function getAudit() {
|
169 |
let req_params = {
|
170 |
pers_model: personalized_model,
|
|
|
171 |
perf_metric: "avg_diff",
|
172 |
breakdown_sort: "difference",
|
173 |
n_topics: 10,
|
|
|
193 |
let req_params = {
|
194 |
cluster: topic,
|
195 |
topic_df_ids: [],
|
|
|
196 |
pers_model: personalized_model,
|
197 |
example_sort: "descending", // TEMP
|
198 |
comparison_group: "status_quo", // TEMP
|
199 |
search_type: "cluster",
|
200 |
keyword: "",
|
|
|
201 |
error_type: cur_error_type,
|
202 |
use_model: use_model,
|
203 |
scaffold_method: scaffold_method,
|
|
|
215 |
<div>
|
216 |
<div style="margin-top: 30px">
|
217 |
<span class="head_3">Auditing</span>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
</div>
|
219 |
<div style="width: 80%">
|
220 |
+
{#if personalized_model}
|
221 |
<p>In this section, we'll be auditing the content moderation system. Here, you’ll be aided by a personalized model that will help direct your attention towards potential problem areas in the model’s performance. This model isn’t meant to be perfect, but is designed to help you better focus on areas that need human review.</p>
|
222 |
+
{:else}
|
223 |
+
<p>Please first train your personalized model by following the steps in the "Labeling" tab (click the top left tab above).</p>
|
224 |
+
{/if}
|
225 |
</div>
|
226 |
|
227 |
{#if show_audit_settings}
|
|
|
271 |
</LayoutGrid>
|
272 |
</div>
|
273 |
</div>
|
274 |
+
{/if}
|
275 |
+
{#if personalized_model}
|
276 |
<p>Current model: {personalized_model}</p>
|
277 |
{/if}
|
278 |
</div>
|
279 |
|
280 |
<!-- 1: All topics overview -->
|
281 |
+
{#if personalized_model}
|
282 |
{#if audit_type == audit_types[0]}
|
283 |
<div class="audit_section">
|
284 |
<div class="head_5">Overview of all topics</div>
|
|
|
432 |
<div class="head_5">Finalize your current report</div>
|
433 |
<p>Finally, review the report you've generated on the side panel and provide a brief summary of the problem you see. You may also list suggestions or insights into addressing this problem if you have ideas. This report will be directly used by the model developers to address the issue you've raised</p>
|
434 |
</div>
|
435 |
+
{/if}
|
436 |
</div>
|
437 |
|
438 |
<style>
|
indie_label_svelte/src/CommentTable.svelte
CHANGED
@@ -5,6 +5,8 @@
|
|
5 |
import DataTable, { Head, Body, Row, Cell } from "@smui/data-table";
|
6 |
import LinearProgress from '@smui/linear-progress';
|
7 |
|
|
|
|
|
8 |
export let mode;
|
9 |
export let model_name;
|
10 |
export let cur_user;
|
@@ -13,6 +15,7 @@
|
|
13 |
let promise = Promise.resolve(null);
|
14 |
let n_complete_ratings;
|
15 |
let n_unsure_ratings;
|
|
|
16 |
|
17 |
function getCommentsToLabel(cur_mode, n) {
|
18 |
if (cur_mode == "train") {
|
@@ -41,6 +44,7 @@
|
|
41 |
}
|
42 |
|
43 |
function handleTrainModelButton() {
|
|
|
44 |
promise = getModel("train");
|
45 |
}
|
46 |
|
@@ -88,7 +92,7 @@
|
|
88 |
const text = await response.text();
|
89 |
const data = JSON.parse(text);
|
90 |
to_label = data["ratings_prev"];
|
91 |
-
|
92 |
return data;
|
93 |
}
|
94 |
</script>
|
@@ -214,12 +218,14 @@
|
|
214 |
{/key}
|
215 |
|
216 |
<div class="spacing_vert_40">
|
217 |
-
<Button on:click={handleTrainModelButton} variant="outlined"
|
218 |
<Label>Train Model</Label>
|
219 |
</Button>
|
|
|
220 |
<Button on:click={getCompleteRatings} variant="outlined">
|
221 |
<Label>Get Number of Comments Labeled</Label>
|
222 |
</Button>
|
|
|
223 |
<Button on:click={() => handleLoadCommentsButton(5)} variant="outlined">
|
224 |
<Label>Fetch More Comments To Label</Label>
|
225 |
</Button>
|
|
|
5 |
import DataTable, { Head, Body, Row, Cell } from "@smui/data-table";
|
6 |
import LinearProgress from '@smui/linear-progress';
|
7 |
|
8 |
+
import { model_chosen } from './stores/cur_model_store.js';
|
9 |
+
|
10 |
export let mode;
|
11 |
export let model_name;
|
12 |
export let cur_user;
|
|
|
15 |
let promise = Promise.resolve(null);
|
16 |
let n_complete_ratings;
|
17 |
let n_unsure_ratings;
|
18 |
+
let show_comments_labeled_count = false;
|
19 |
|
20 |
function getCommentsToLabel(cur_mode, n) {
|
21 |
if (cur_mode == "train") {
|
|
|
44 |
}
|
45 |
|
46 |
function handleTrainModelButton() {
|
47 |
+
getCompleteRatings();
|
48 |
promise = getModel("train");
|
49 |
}
|
50 |
|
|
|
92 |
const text = await response.text();
|
93 |
const data = JSON.parse(text);
|
94 |
to_label = data["ratings_prev"];
|
95 |
+
model_chosen.update((value) => model_name);
|
96 |
return data;
|
97 |
}
|
98 |
</script>
|
|
|
218 |
{/key}
|
219 |
|
220 |
<div class="spacing_vert_40">
|
221 |
+
<Button on:click={handleTrainModelButton} variant="outlined">
|
222 |
<Label>Train Model</Label>
|
223 |
</Button>
|
224 |
+
{#if show_comments_labeled_count}
|
225 |
<Button on:click={getCompleteRatings} variant="outlined">
|
226 |
<Label>Get Number of Comments Labeled</Label>
|
227 |
</Button>
|
228 |
+
{/if}
|
229 |
<Button on:click={() => handleLoadCommentsButton(5)} variant="outlined">
|
230 |
<Label>Fetch More Comments To Label</Label>
|
231 |
</Button>
|
indie_label_svelte/src/Hunch.svelte
CHANGED
@@ -1,9 +1,7 @@
|
|
1 |
<script lang="ts">
|
2 |
import { onMount } from "svelte";
|
3 |
-
import IterativeClustering from "./IterativeClustering.svelte";
|
4 |
import Button, { Label } from "@smui/button";
|
5 |
import Textfield from '@smui/textfield';
|
6 |
-
import LinearProgress from "@smui/linear-progress";
|
7 |
|
8 |
export let ind;
|
9 |
export let hunch;
|
@@ -32,7 +30,6 @@
|
|
32 |
|
33 |
<div>
|
34 |
<div>
|
35 |
-
<!-- <h6>Hunch {ind + 1}</h6> -->
|
36 |
<h6>Topic:</h6>
|
37 |
{topic}
|
38 |
</div>
|
@@ -46,13 +43,6 @@
|
|
46 |
label="My current hunch is that..."
|
47 |
>
|
48 |
</Textfield>
|
49 |
-
<!-- <Button
|
50 |
-
on:click={handleTestOnExamples}
|
51 |
-
class="button_float_right spacing_vert"
|
52 |
-
variant="outlined"
|
53 |
-
>
|
54 |
-
<Label>Test on examples</Label>
|
55 |
-
</Button> -->
|
56 |
</div>
|
57 |
|
58 |
<div class="spacing_vert">
|
@@ -63,23 +53,7 @@
|
|
63 |
<Label>Submit</Label>
|
64 |
</Button>
|
65 |
</div>
|
66 |
-
|
67 |
-
<!-- {#await example_block}
|
68 |
-
<div class="app_loading">
|
69 |
-
<LinearProgress indeterminate />
|
70 |
-
</div>
|
71 |
-
{:then} -->
|
72 |
-
<!-- {#if example_block}
|
73 |
-
<IterativeClustering clusters={clusters} ind={ind + 1} personalized_model={model} />
|
74 |
-
{/if} -->
|
75 |
-
<!-- {:catch error}
|
76 |
-
<p style="color: red">{error.message}</p>
|
77 |
-
{/await} -->
|
78 |
</div>
|
79 |
|
80 |
<style>
|
81 |
-
/* * {
|
82 |
-
z-index: 11;
|
83 |
-
overflow-x: hidden;
|
84 |
-
} */
|
85 |
</style>
|
|
|
1 |
<script lang="ts">
|
2 |
import { onMount } from "svelte";
|
|
|
3 |
import Button, { Label } from "@smui/button";
|
4 |
import Textfield from '@smui/textfield';
|
|
|
5 |
|
6 |
export let ind;
|
7 |
export let hunch;
|
|
|
30 |
|
31 |
<div>
|
32 |
<div>
|
|
|
33 |
<h6>Topic:</h6>
|
34 |
{topic}
|
35 |
</div>
|
|
|
43 |
label="My current hunch is that..."
|
44 |
>
|
45 |
</Textfield>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
</div>
|
47 |
|
48 |
<div class="spacing_vert">
|
|
|
53 |
<Label>Submit</Label>
|
54 |
</Button>
|
55 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
</div>
|
57 |
|
58 |
<style>
|
|
|
|
|
|
|
|
|
59 |
</style>
|
indie_label_svelte/src/HypothesisPanel.svelte
CHANGED
@@ -225,6 +225,7 @@
|
|
225 |
<Button
|
226 |
on:click={() => (open = !open)}
|
227 |
color="primary"
|
|
|
228 |
style="float: right; padding: 10px; margin-right: 10px;"
|
229 |
>
|
230 |
{#if open}
|
@@ -239,6 +240,11 @@
|
|
239 |
</div>
|
240 |
</div>
|
241 |
|
|
|
|
|
|
|
|
|
|
|
242 |
<div class="panel_contents">
|
243 |
<!-- Drawer -->
|
244 |
{#await promise}
|
@@ -491,7 +497,7 @@
|
|
491 |
</div>
|
492 |
</div>
|
493 |
</div>
|
494 |
-
|
495 |
<!-- TEMP -->
|
496 |
<!-- {#key model}
|
497 |
<div>Model: {model}</div>
|
|
|
225 |
<Button
|
226 |
on:click={() => (open = !open)}
|
227 |
color="primary"
|
228 |
+
disabled={model == null}
|
229 |
style="float: right; padding: 10px; margin-right: 10px;"
|
230 |
>
|
231 |
{#if open}
|
|
|
240 |
</div>
|
241 |
</div>
|
242 |
|
243 |
+
{#if model == null}
|
244 |
+
<div class="panel_contents">
|
245 |
+
<p>You can start to author audit reports in this panel after you've trained your personalized model in the "Labeling" tab.</p>
|
246 |
+
</div>
|
247 |
+
{:else}
|
248 |
<div class="panel_contents">
|
249 |
<!-- Drawer -->
|
250 |
{#await promise}
|
|
|
497 |
</div>
|
498 |
</div>
|
499 |
</div>
|
500 |
+
{/if}
|
501 |
<!-- TEMP -->
|
502 |
<!-- {#key model}
|
503 |
<div>Model: {model}</div>
|
indie_label_svelte/src/IterativeClustering.svelte
DELETED
@@ -1,164 +0,0 @@
|
|
1 |
-
<script>
|
2 |
-
import Section from "./Section.svelte";
|
3 |
-
import ClusterResults from "./ClusterResults.svelte";
|
4 |
-
import Button, { Label } from "@smui/button";
|
5 |
-
import Textfield from "@smui/textfield";
|
6 |
-
import LayoutGrid, { Cell } from "@smui/layout-grid";
|
7 |
-
import LinearProgress from "@smui/linear-progress";
|
8 |
-
import Chip, { Set, Text } from '@smui/chips';
|
9 |
-
|
10 |
-
export let clusters;
|
11 |
-
export let personalized_model;
|
12 |
-
export let evidence;
|
13 |
-
export let width_pct = 80;
|
14 |
-
|
15 |
-
let topic_df_ids = [];
|
16 |
-
let promise_iter_cluster = Promise.resolve(null);
|
17 |
-
let keyword = null;
|
18 |
-
let n_neighbors = null;
|
19 |
-
let cur_iter_cluster = null;
|
20 |
-
let history = [];
|
21 |
-
|
22 |
-
async function getIterCluster(search_type) {
|
23 |
-
let req_params = {
|
24 |
-
cluster: cur_iter_cluster,
|
25 |
-
topic_df_ids: topic_df_ids,
|
26 |
-
n_examples: 500, // TEMP
|
27 |
-
pers_model: personalized_model,
|
28 |
-
example_sort: "descending", // TEMP
|
29 |
-
comparison_group: "status_quo", // TEMP
|
30 |
-
search_type: search_type,
|
31 |
-
keyword: keyword,
|
32 |
-
n_neighbors: n_neighbors,
|
33 |
-
};
|
34 |
-
console.log("topic_df_ids", topic_df_ids);
|
35 |
-
let params = new URLSearchParams(req_params).toString();
|
36 |
-
const response = await fetch("./get_cluster_results?" + params);
|
37 |
-
const text = await response.text();
|
38 |
-
const data = JSON.parse(text);
|
39 |
-
// if (data["cluster_comments"] == null) {
|
40 |
-
// return false
|
41 |
-
// }
|
42 |
-
topic_df_ids = data["topic_df_ids"];
|
43 |
-
return data;
|
44 |
-
}
|
45 |
-
|
46 |
-
function findCluster() {
|
47 |
-
promise_iter_cluster = getIterCluster("cluster");
|
48 |
-
history = history.concat("bulk-add cluster: " + cur_iter_cluster);
|
49 |
-
}
|
50 |
-
|
51 |
-
function findNeighbors() {
|
52 |
-
promise_iter_cluster = getIterCluster("neighbors");
|
53 |
-
history = history.concat("find " + n_neighbors + " neighbors");
|
54 |
-
}
|
55 |
-
|
56 |
-
function findKeywords() {
|
57 |
-
promise_iter_cluster = getIterCluster("keyword");
|
58 |
-
history = history.concat("keyword search: " + keyword);
|
59 |
-
}
|
60 |
-
</script>
|
61 |
-
|
62 |
-
<div>
|
63 |
-
<div>
|
64 |
-
<!-- <h6>Hunch {ind} examples</h6> -->
|
65 |
-
<div>
|
66 |
-
<h6>Search Settings</h6>
|
67 |
-
<!-- Start with cluster -->
|
68 |
-
<!-- <div class="">
|
69 |
-
<Section
|
70 |
-
section_id="iter_cluster"
|
71 |
-
section_title="Bulk-add cluster"
|
72 |
-
section_opts={clusters}
|
73 |
-
bind:value={cur_iter_cluster}
|
74 |
-
width_pct={100}
|
75 |
-
/>
|
76 |
-
<Button
|
77 |
-
on:click={findCluster}
|
78 |
-
variant="outlined"
|
79 |
-
class="button_float_right"
|
80 |
-
disabled={cur_iter_cluster == null}
|
81 |
-
>
|
82 |
-
<Label>Search</Label>
|
83 |
-
</Button>
|
84 |
-
</div> -->
|
85 |
-
|
86 |
-
<!-- Manual keyword -->
|
87 |
-
<div class="spacing_vert">
|
88 |
-
<Textfield
|
89 |
-
bind:value={keyword}
|
90 |
-
label="Keyword search"
|
91 |
-
variant="outlined"
|
92 |
-
style="width: {width_pct}%"
|
93 |
-
/>
|
94 |
-
<Button
|
95 |
-
on:click={findKeywords}
|
96 |
-
variant="outlined"
|
97 |
-
class="button_float_right spacing_vert"
|
98 |
-
disabled={keyword == null}
|
99 |
-
>
|
100 |
-
<Label>Search</Label>
|
101 |
-
</Button>
|
102 |
-
</div>
|
103 |
-
|
104 |
-
<!-- Find neighbors of current set -->
|
105 |
-
<div class="spacing_vert">
|
106 |
-
<Textfield
|
107 |
-
bind:value={n_neighbors}
|
108 |
-
label="Number of neighbors to retrieve"
|
109 |
-
type="number"
|
110 |
-
min="1"
|
111 |
-
max="50"
|
112 |
-
variant="outlined"
|
113 |
-
style="width: {width_pct}%"
|
114 |
-
/>
|
115 |
-
<Button
|
116 |
-
on:click={findNeighbors}
|
117 |
-
variant="outlined"
|
118 |
-
class="button_float_right spacing_vert"
|
119 |
-
disabled={n_neighbors == null}
|
120 |
-
>
|
121 |
-
<Label>Search</Label>
|
122 |
-
</Button>
|
123 |
-
</div>
|
124 |
-
</div>
|
125 |
-
</div>
|
126 |
-
|
127 |
-
{#await promise_iter_cluster}
|
128 |
-
<div class="app_loading" style="width: {width_pct}%">
|
129 |
-
<LinearProgress indeterminate />
|
130 |
-
</div>
|
131 |
-
{:then iter_cluster_results}
|
132 |
-
{#if iter_cluster_results}
|
133 |
-
{#if history.length > 0}
|
134 |
-
<div class="bold" style="padding-top:40px;">Search History</div>
|
135 |
-
<Set chips={history} let:chip choice>
|
136 |
-
<Chip {chip}>
|
137 |
-
<Text>{chip}</Text>
|
138 |
-
</Chip>
|
139 |
-
</Set>
|
140 |
-
{/if}
|
141 |
-
{#if iter_cluster_results.cluster_comments != null}
|
142 |
-
<ClusterResults
|
143 |
-
cluster={""}
|
144 |
-
clusters={clusters}
|
145 |
-
model={personalized_model}
|
146 |
-
data={iter_cluster_results}
|
147 |
-
show_vis={false}
|
148 |
-
table_width_pct={80}
|
149 |
-
bind:evidence={evidence}
|
150 |
-
on:change
|
151 |
-
/>
|
152 |
-
{:else}
|
153 |
-
<div class="bold" style="padding-top:40px;">
|
154 |
-
No results found
|
155 |
-
</div>
|
156 |
-
{/if}
|
157 |
-
{/if}
|
158 |
-
{:catch error}
|
159 |
-
<p style="color: red">{error.message}</p>
|
160 |
-
{/await}
|
161 |
-
</div>
|
162 |
-
|
163 |
-
<style>
|
164 |
-
</style>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
indie_label_svelte/src/KeywordSearch.svelte
CHANGED
@@ -17,7 +17,6 @@
|
|
17 |
let topic_df_ids = [];
|
18 |
let promise_iter_cluster = Promise.resolve(null);
|
19 |
let keyword = null;
|
20 |
-
let n_neighbors = null;
|
21 |
let cur_iter_cluster = null;
|
22 |
let history = [];
|
23 |
|
@@ -30,13 +29,11 @@
|
|
30 |
let req_params = {
|
31 |
cluster: cur_iter_cluster,
|
32 |
topic_df_ids: topic_df_ids,
|
33 |
-
n_examples: 500, // TEMP
|
34 |
pers_model: personalized_model,
|
35 |
example_sort: "descending", // TEMP
|
36 |
comparison_group: "status_quo", // TEMP
|
37 |
search_type: search_type,
|
38 |
keyword: keyword,
|
39 |
-
n_neighbors: n_neighbors,
|
40 |
error_type: cur_error_type,
|
41 |
};
|
42 |
console.log("topic_df_ids", topic_df_ids);
|
|
|
17 |
let topic_df_ids = [];
|
18 |
let promise_iter_cluster = Promise.resolve(null);
|
19 |
let keyword = null;
|
|
|
20 |
let cur_iter_cluster = null;
|
21 |
let history = [];
|
22 |
|
|
|
29 |
let req_params = {
|
30 |
cluster: cur_iter_cluster,
|
31 |
topic_df_ids: topic_df_ids,
|
|
|
32 |
pers_model: personalized_model,
|
33 |
example_sort: "descending", // TEMP
|
34 |
comparison_group: "status_quo", // TEMP
|
35 |
search_type: search_type,
|
36 |
keyword: keyword,
|
|
|
37 |
error_type: cur_error_type,
|
38 |
};
|
39 |
console.log("topic_df_ids", topic_df_ids);
|
indie_label_svelte/src/Labeling.svelte
CHANGED
@@ -17,7 +17,7 @@
|
|
17 |
let label_modes = [
|
18 |
"Create a new model",
|
19 |
"Edit an existing model",
|
20 |
-
"Tune your model for a topic area",
|
21 |
// "Set up a group-based model",
|
22 |
];
|
23 |
|
@@ -33,6 +33,7 @@
|
|
33 |
} else if (req_label_mode == 1) {
|
34 |
label_mode = label_modes[1];
|
35 |
} else if (req_label_mode == 2) {
|
|
|
36 |
label_mode = label_modes[2];
|
37 |
} else if (req_label_mode == 3) {
|
38 |
// Unused; previous group-based mode
|
|
|
17 |
let label_modes = [
|
18 |
"Create a new model",
|
19 |
"Edit an existing model",
|
20 |
+
// "Tune your model for a topic area",
|
21 |
// "Set up a group-based model",
|
22 |
];
|
23 |
|
|
|
33 |
} else if (req_label_mode == 1) {
|
34 |
label_mode = label_modes[1];
|
35 |
} else if (req_label_mode == 2) {
|
36 |
+
// Unused; previous topic-based mode
|
37 |
label_mode = label_modes[2];
|
38 |
} else if (req_label_mode == 3) {
|
39 |
// Unused; previous group-based mode
|
server.py
CHANGED
@@ -37,7 +37,6 @@ def home(path):
|
|
37 |
|
38 |
########################################
|
39 |
# ROUTE: /AUDIT_SETTINGS
|
40 |
-
comments_grouped_full_topic_cat = pd.read_pickle("data/comments_grouped_full_topic_cat2_persp.pkl")
|
41 |
|
42 |
@app.route("/audit_settings")
|
43 |
def audit_settings(debug=DEBUG):
|
@@ -47,13 +46,10 @@ def audit_settings(debug=DEBUG):
|
|
47 |
|
48 |
# Assign user ID if none is provided (default case)
|
49 |
if user == "null":
|
50 |
-
|
51 |
-
|
52 |
-
else:
|
53 |
-
# Generate random two-word user ID
|
54 |
-
user = fw.generate(2, separator="_")
|
55 |
|
56 |
-
user_models = utils.
|
57 |
grp_models = [m for m in user_models if m.startswith(f"model_{user}_group_")]
|
58 |
|
59 |
clusters = utils.get_unique_topics()
|
@@ -76,19 +72,6 @@ def audit_settings(debug=DEBUG):
|
|
76 |
"options": [{"value": i, "text": cluster} for i, cluster in enumerate(clusters)],
|
77 |
},]
|
78 |
|
79 |
-
if scaffold_method == "personal_cluster":
|
80 |
-
cluster_model = user_models[0]
|
81 |
-
personal_cluster_file = f"./data/personal_cluster_dfs/{cluster_model}.pkl"
|
82 |
-
if os.path.isfile(personal_cluster_file) and cluster_model != "":
|
83 |
-
print("audit_settings", personal_cluster_file, cluster_model)
|
84 |
-
topics_under_top, topics_over_top = utils.get_personal_clusters(cluster_model)
|
85 |
-
pers_cluster = topics_under_top + topics_over_top
|
86 |
-
pers_cluster_options = {
|
87 |
-
"label": "Personalized clusters",
|
88 |
-
"options": [{"value": i, "text": cluster} for i, cluster in enumerate(pers_cluster)],
|
89 |
-
}
|
90 |
-
clusters_options.insert(0, pers_cluster_options)
|
91 |
-
|
92 |
clusters_for_tuning = utils.get_large_clusters(min_n=150)
|
93 |
clusters_for_tuning_options = [{"value": i, "text": cluster} for i, cluster in enumerate(clusters_for_tuning)] # Format for Svelecte UI element
|
94 |
|
@@ -96,7 +79,6 @@ def audit_settings(debug=DEBUG):
|
|
96 |
"personalized_models": user_models,
|
97 |
"personalized_model_grp": grp_models,
|
98 |
"perf_metrics": ["Average rating difference", "Mean Absolute Error (MAE)", "Root Mean Squared Error (RMSE)", "Mean Squared Error (MSE)"],
|
99 |
-
"breakdown_categories": ['Topic', 'Toxicity Category', 'Toxicity Severity'],
|
100 |
"clusters": clusters_options,
|
101 |
"clusters_for_tuning": clusters_for_tuning_options,
|
102 |
"user": user,
|
@@ -109,30 +91,21 @@ def audit_settings(debug=DEBUG):
|
|
109 |
@app.route("/get_audit")
|
110 |
def get_audit():
|
111 |
pers_model = request.args.get("pers_model")
|
112 |
-
perf_metric = request.args.get("perf_metric")
|
113 |
-
breakdown_axis = request.args.get("breakdown_axis")
|
114 |
-
breakdown_sort = request.args.get("breakdown_sort")
|
115 |
-
n_topics = int(request.args.get("n_topics"))
|
116 |
error_type = request.args.get("error_type")
|
117 |
cur_user = request.args.get("cur_user")
|
118 |
topic_vis_method = request.args.get("topic_vis_method")
|
119 |
if topic_vis_method == "null":
|
120 |
topic_vis_method = "median"
|
121 |
|
122 |
-
if
|
123 |
-
|
124 |
-
elif breakdown_sort == "default":
|
125 |
-
sort_class_plot = False
|
126 |
else:
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
breakdown_axis=breakdown_axis,
|
134 |
-
topic_vis_method=topic_vis_method,
|
135 |
-
)
|
136 |
|
137 |
results = {
|
138 |
"overall_perf": overall_perf,
|
@@ -142,60 +115,32 @@ def get_audit():
|
|
142 |
########################################
|
143 |
# ROUTE: /GET_CLUSTER_RESULTS
|
144 |
@app.route("/get_cluster_results")
|
145 |
-
def get_cluster_results():
|
146 |
pers_model = request.args.get("pers_model")
|
147 |
-
n_examples = int(request.args.get("n_examples"))
|
148 |
cluster = request.args.get("cluster")
|
149 |
-
example_sort = request.args.get("example_sort")
|
150 |
-
comparison_group = request.args.get("comparison_group")
|
151 |
topic_df_ids = request.args.getlist("topic_df_ids")
|
152 |
topic_df_ids = [int(val) for val in topic_df_ids[0].split(",") if val != ""]
|
153 |
search_type = request.args.get("search_type")
|
154 |
keyword = request.args.get("keyword")
|
155 |
-
n_neighbors = request.args.get("n_neighbors")
|
156 |
-
if n_neighbors != "null":
|
157 |
-
n_neighbors = int(n_neighbors)
|
158 |
-
neighbor_threshold = 0.6
|
159 |
error_type = request.args.get("error_type")
|
160 |
use_model = request.args.get("use_model") == "true"
|
161 |
-
scaffold_method = request.args.get("scaffold_method")
|
162 |
-
|
163 |
-
|
164 |
-
# If user has a tuned model for this cluster, use that
|
165 |
-
cluster_model_file = f"./data/trained_models/{pers_model}_{cluster}.pkl"
|
166 |
-
if os.path.isfile(cluster_model_file):
|
167 |
-
pers_model = f"{pers_model}_{cluster}"
|
168 |
-
|
169 |
-
print(f"get_cluster_results using model {pers_model}")
|
170 |
|
171 |
-
|
172 |
-
|
173 |
-
sort_ascending = True if example_sort == "ascending" else False
|
174 |
|
|
|
175 |
topic_df = None
|
176 |
-
|
177 |
-
|
178 |
-
if
|
179 |
-
#
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
# Regular handling
|
185 |
-
with open(f"data/preds_dfs/{pers_model}.pkl", "rb") as f:
|
186 |
-
topic_df = pickle.load(f)
|
187 |
-
if search_type == "cluster":
|
188 |
-
# Display examples with comment, your pred, and other users' pred
|
189 |
-
topic_df = topic_df[(topic_df["topic"] == cluster) | (topic_df["item_id"].isin(topic_df_ids))]
|
190 |
-
|
191 |
-
elif search_type == "neighbors":
|
192 |
-
neighbor_ids = utils.get_match(topic_df_ids, K=n_neighbors, threshold=neighbor_threshold, debug=False)
|
193 |
-
topic_df = topic_df[(topic_df["item_id"].isin(neighbor_ids)) | (topic_df["item_id"].isin(topic_df_ids))]
|
194 |
-
elif search_type == "keyword":
|
195 |
-
topic_df = topic_df[(topic_df["comment"].str.contains(keyword, case=False, regex=False)) | (topic_df["item_id"].isin(topic_df_ids))]
|
196 |
-
|
197 |
topic_df = topic_df.drop_duplicates()
|
198 |
-
|
|
|
199 |
|
200 |
# Handle empty results
|
201 |
if len(topic_df) == 0:
|
@@ -216,24 +161,20 @@ def get_cluster_results():
|
|
216 |
|
217 |
topic_df_ids = topic_df["item_id"].unique().tolist()
|
218 |
|
219 |
-
|
|
|
|
|
220 |
cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster(topic_df, error_type=error_type, n_comments=500)
|
221 |
else:
|
222 |
-
#
|
223 |
-
|
224 |
-
if use_model:
|
225 |
-
# Display results with the model as a reference point
|
226 |
-
cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster(topic_df_mod, error_type=error_type, n_comments=500)
|
227 |
-
else:
|
228 |
-
# Display results without a model
|
229 |
-
cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster_no_model(topic_df_mod, n_comments=500)
|
230 |
|
231 |
-
cluster_comments = utils.get_cluster_comments(sampled_df,error_type=error_type,
|
232 |
|
233 |
results = {
|
234 |
"topic_df_ids": topic_df_ids,
|
235 |
"cluster_overview_plot_json": json.loads(cluster_overview_plot_json),
|
236 |
-
"cluster_comments": cluster_comments,
|
237 |
}
|
238 |
return json.dumps(results)
|
239 |
|
@@ -280,7 +221,6 @@ def get_group_model():
|
|
280 |
grp_ids = grp_df["worker_id"].tolist()
|
281 |
|
282 |
ratings_grp = utils.get_grp_model_labels(
|
283 |
-
comments_df=comments_grouped_full_topic_cat,
|
284 |
n_label_per_bin=BIN_DISTRIB,
|
285 |
score_bins=SCORE_BINS,
|
286 |
grp_ids=grp_ids,
|
@@ -322,7 +262,7 @@ def get_labeling():
|
|
322 |
model_name_suggestion = f"my_model"
|
323 |
|
324 |
context = {
|
325 |
-
"personalized_models": utils.
|
326 |
"model_name_suggestion": model_name_suggestion,
|
327 |
"clusters_for_tuning": clusters_for_tuning_options,
|
328 |
}
|
@@ -330,15 +270,16 @@ def get_labeling():
|
|
330 |
|
331 |
########################################
|
332 |
# ROUTE: /GET_COMMENTS_TO_LABEL
|
333 |
-
|
334 |
-
BIN_DISTRIB = [
|
|
|
|
|
335 |
SCORE_BINS = [(0.0, 0.5), (0.5, 1.5), (1.5, 2.5), (2.5, 3.5), (3.5, 4.01)]
|
336 |
@app.route("/get_comments_to_label")
|
337 |
def get_comments_to_label():
|
338 |
n = int(request.args.get("n"))
|
339 |
# Fetch examples to label
|
340 |
to_label_ids = utils.create_example_sets(
|
341 |
-
comments_df=comments_grouped_full_topic_cat,
|
342 |
n_label_per_bin=BIN_DISTRIB,
|
343 |
score_bins=SCORE_BINS,
|
344 |
keyword=None
|
@@ -355,14 +296,11 @@ def get_comments_to_label():
|
|
355 |
|
356 |
########################################
|
357 |
# ROUTE: /GET_COMMENTS_TO_LABEL_TOPIC
|
358 |
-
N_LABEL_PER_BIN_TOPIC = 2 # 2 * 5 = 10 comments
|
359 |
@app.route("/get_comments_to_label_topic")
|
360 |
def get_comments_to_label_topic():
|
361 |
# Fetch examples to label
|
362 |
topic = request.args.get("topic")
|
363 |
to_label_ids = utils.create_example_sets(
|
364 |
-
comments_df=comments_grouped_full_topic_cat,
|
365 |
-
# n_label_per_bin=N_LABEL_PER_BIN_TOPIC,
|
366 |
n_label_per_bin=BIN_DISTRIB,
|
367 |
score_bins=SCORE_BINS,
|
368 |
keyword=None,
|
@@ -397,10 +335,7 @@ def get_personalized_model():
|
|
397 |
# Handle existing or new model cases
|
398 |
if mode == "view":
|
399 |
# Fetch prior model performance
|
400 |
-
|
401 |
-
raise Exception(f"Model {model_name} does not exist")
|
402 |
-
else:
|
403 |
-
mae, mse, rmse, avg_diff, ratings_prev = utils.fetch_existing_data(model_name, last_label_i)
|
404 |
|
405 |
elif mode == "train":
|
406 |
# Train model and cache predictions using new labels
|
@@ -490,8 +425,6 @@ def get_reports():
|
|
490 |
reports = get_fixed_scaffold()
|
491 |
elif (scaffold_method == "personal" or scaffold_method == "personal_group" or scaffold_method == "personal_test"):
|
492 |
reports = get_personal_scaffold(model, topic_vis_method)
|
493 |
-
elif (scaffold_method == "personal_cluster"):
|
494 |
-
reports = get_personal_cluster_scaffold(model)
|
495 |
elif scaffold_method == "prompts":
|
496 |
reports = get_prompts_scaffold()
|
497 |
elif scaffold_method == "tutorial":
|
@@ -576,21 +509,11 @@ def get_tutorial_scaffold():
|
|
576 |
},
|
577 |
]
|
578 |
|
579 |
-
def get_personal_cluster_scaffold(model):
|
580 |
-
topics_under_top, topics_over_top = utils.get_personal_clusters(model)
|
581 |
-
|
582 |
-
report_under = [get_empty_report(topic, "System is under-sensitive") for topic in topics_under_top]
|
583 |
-
|
584 |
-
report_over = [get_empty_report(topic, "System is over-sensitive") for topic in topics_over_top]
|
585 |
-
reports = (report_under + report_over)
|
586 |
-
random.shuffle(reports)
|
587 |
-
return reports
|
588 |
-
|
589 |
def get_topic_errors(df, topic_vis_method, threshold=2):
|
590 |
-
topics = df["
|
591 |
topic_errors = {}
|
592 |
for topic in topics:
|
593 |
-
t_df = df[df["
|
594 |
y_true = t_df["pred"].to_numpy()
|
595 |
y_pred = t_df["rating"].to_numpy()
|
596 |
if topic_vis_method == "mae":
|
@@ -627,27 +550,28 @@ def get_personal_scaffold(model, topic_vis_method, n_topics=200, n=5):
|
|
627 |
# Get topics with greatest amount of error
|
628 |
with open(f"./data/preds_dfs/{model}.pkl", "rb") as f:
|
629 |
preds_df = pickle.load(f)
|
630 |
-
|
|
|
631 |
preds_df_mod = preds_df_mod[preds_df_mod["user_id"] == "A"].sort_values(by=["item_id"]).reset_index()
|
632 |
-
preds_df_mod = preds_df_mod[preds_df_mod["
|
633 |
|
634 |
if topic_vis_method == "median":
|
635 |
-
df = preds_df_mod.groupby(["
|
636 |
elif topic_vis_method == "mean":
|
637 |
-
df = preds_df_mod.groupby(["
|
638 |
elif topic_vis_method == "fp_fn":
|
639 |
for error_type in ["fn_proportion", "fp_proportion"]:
|
640 |
topic_errors = get_topic_errors(preds_df_mod, error_type)
|
641 |
-
preds_df_mod[error_type] = [topic_errors[topic] for topic in preds_df_mod["
|
642 |
-
df = preds_df_mod.groupby(["
|
643 |
else:
|
644 |
# Get error for each topic
|
645 |
topic_errors = get_topic_errors(preds_df_mod, topic_vis_method)
|
646 |
-
preds_df_mod[topic_vis_method] = [topic_errors[topic] for topic in preds_df_mod["
|
647 |
-
df = preds_df_mod.groupby(["
|
648 |
|
649 |
# Get system error
|
650 |
-
df = df[(df["
|
651 |
|
652 |
if topic_vis_method == "median" or topic_vis_method == "mean":
|
653 |
df["error_magnitude"] = [utils.get_error_magnitude(sys, user, threshold) for sys, user in zip(df["rating"].tolist(), df["pred"].tolist())]
|
@@ -655,31 +579,30 @@ def get_personal_scaffold(model, topic_vis_method, n_topics=200, n=5):
|
|
655 |
|
656 |
df_under = df[df["error_type"] == "System is under-sensitive"]
|
657 |
df_under = df_under.sort_values(by=["error_magnitude"], ascending=False).head(n) # surface largest errors first
|
658 |
-
report_under = [get_empty_report(row["
|
659 |
|
660 |
df_over = df[df["error_type"] == "System is over-sensitive"]
|
661 |
df_over = df_over.sort_values(by=["error_magnitude"], ascending=False).head(n) # surface largest errors first
|
662 |
-
report_over = [get_empty_report(row["
|
663 |
|
664 |
# Set up reports
|
665 |
-
# return [get_empty_report(row["topic_"], row["error_type"]) for index, row in df.iterrows()]
|
666 |
reports = (report_under + report_over)
|
667 |
random.shuffle(reports)
|
668 |
elif topic_vis_method == "fp_fn":
|
669 |
df_under = df.sort_values(by=["fn_proportion"], ascending=False).head(n)
|
670 |
df_under = df_under[df_under["fn_proportion"] > 0]
|
671 |
-
report_under = [get_empty_report(row["
|
672 |
|
673 |
df_over = df.sort_values(by=["fp_proportion"], ascending=False).head(n)
|
674 |
df_over = df_over[df_over["fp_proportion"] > 0]
|
675 |
-
report_over = [get_empty_report(row["
|
676 |
|
677 |
reports = (report_under + report_over)
|
678 |
random.shuffle(reports)
|
679 |
else:
|
680 |
df = df.sort_values(by=[topic_vis_method], ascending=False).head(n * 2)
|
681 |
df["error_type"] = [utils.get_error_type_radio(sys, user, threshold) for sys, user in zip(df["rating"].tolist(), df["pred"].tolist())]
|
682 |
-
reports = [get_empty_report(row["
|
683 |
|
684 |
return reports
|
685 |
|
@@ -750,11 +673,7 @@ def get_explore_examples():
|
|
750 |
n_examples = int(request.args.get("n_examples"))
|
751 |
|
752 |
# Get sample of examples
|
753 |
-
df = utils.
|
754 |
-
|
755 |
-
df["system_decision"] = [utils.get_decision(rating, threshold) for rating in df["rating"].tolist()]
|
756 |
-
df["system_color"] = [utils.get_user_color(sys, threshold) for sys in df["rating"].tolist()] # get cell colors
|
757 |
-
|
758 |
ex_json = df.to_json(orient="records")
|
759 |
|
760 |
results = {
|
|
|
37 |
|
38 |
########################################
|
39 |
# ROUTE: /AUDIT_SETTINGS
|
|
|
40 |
|
41 |
@app.route("/audit_settings")
|
42 |
def audit_settings(debug=DEBUG):
|
|
|
46 |
|
47 |
# Assign user ID if none is provided (default case)
|
48 |
if user == "null":
|
49 |
+
# Generate random two-word user ID
|
50 |
+
user = fw.generate(2, separator="_")
|
|
|
|
|
|
|
51 |
|
52 |
+
user_models = utils.get_user_model_names(user)
|
53 |
grp_models = [m for m in user_models if m.startswith(f"model_{user}_group_")]
|
54 |
|
55 |
clusters = utils.get_unique_topics()
|
|
|
72 |
"options": [{"value": i, "text": cluster} for i, cluster in enumerate(clusters)],
|
73 |
},]
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
clusters_for_tuning = utils.get_large_clusters(min_n=150)
|
76 |
clusters_for_tuning_options = [{"value": i, "text": cluster} for i, cluster in enumerate(clusters_for_tuning)] # Format for Svelecte UI element
|
77 |
|
|
|
79 |
"personalized_models": user_models,
|
80 |
"personalized_model_grp": grp_models,
|
81 |
"perf_metrics": ["Average rating difference", "Mean Absolute Error (MAE)", "Root Mean Squared Error (RMSE)", "Mean Squared Error (MSE)"],
|
|
|
82 |
"clusters": clusters_options,
|
83 |
"clusters_for_tuning": clusters_for_tuning_options,
|
84 |
"user": user,
|
|
|
91 |
@app.route("/get_audit")
|
92 |
def get_audit():
|
93 |
pers_model = request.args.get("pers_model")
|
|
|
|
|
|
|
|
|
94 |
error_type = request.args.get("error_type")
|
95 |
cur_user = request.args.get("cur_user")
|
96 |
topic_vis_method = request.args.get("topic_vis_method")
|
97 |
if topic_vis_method == "null":
|
98 |
topic_vis_method = "median"
|
99 |
|
100 |
+
if pers_model == "" or pers_model == "null" or pers_model == "undefined":
|
101 |
+
overall_perf = None
|
|
|
|
|
102 |
else:
|
103 |
+
overall_perf = utils.show_overall_perf(
|
104 |
+
variant=pers_model,
|
105 |
+
error_type=error_type,
|
106 |
+
cur_user=cur_user,
|
107 |
+
topic_vis_method=topic_vis_method,
|
108 |
+
)
|
|
|
|
|
|
|
109 |
|
110 |
results = {
|
111 |
"overall_perf": overall_perf,
|
|
|
115 |
########################################
|
116 |
# ROUTE: /GET_CLUSTER_RESULTS
|
117 |
@app.route("/get_cluster_results")
|
118 |
+
def get_cluster_results(debug=DEBUG):
|
119 |
pers_model = request.args.get("pers_model")
|
|
|
120 |
cluster = request.args.get("cluster")
|
|
|
|
|
121 |
topic_df_ids = request.args.getlist("topic_df_ids")
|
122 |
topic_df_ids = [int(val) for val in topic_df_ids[0].split(",") if val != ""]
|
123 |
search_type = request.args.get("search_type")
|
124 |
keyword = request.args.get("keyword")
|
|
|
|
|
|
|
|
|
125 |
error_type = request.args.get("error_type")
|
126 |
use_model = request.args.get("use_model") == "true"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
+
if debug:
|
129 |
+
print(f"get_cluster_results using model {pers_model}")
|
|
|
130 |
|
131 |
+
# Prepare cluster df (topic_df)
|
132 |
topic_df = None
|
133 |
+
with open(f"data/preds_dfs/{pers_model}.pkl", "rb") as f:
|
134 |
+
topic_df = pickle.load(f)
|
135 |
+
if search_type == "cluster":
|
136 |
+
# Display examples with comment, your pred, and other users' pred
|
137 |
+
topic_df = topic_df[(topic_df["topic"] == cluster) | (topic_df["item_id"].isin(topic_df_ids))]
|
138 |
+
elif search_type == "keyword":
|
139 |
+
topic_df = topic_df[(topic_df["comment"].str.contains(keyword, case=False, regex=False)) | (topic_df["item_id"].isin(topic_df_ids))]
|
140 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
topic_df = topic_df.drop_duplicates()
|
142 |
+
if debug:
|
143 |
+
print("len topic_df", len(topic_df))
|
144 |
|
145 |
# Handle empty results
|
146 |
if len(topic_df) == 0:
|
|
|
161 |
|
162 |
topic_df_ids = topic_df["item_id"].unique().tolist()
|
163 |
|
164 |
+
# Prepare overview plot for the cluster
|
165 |
+
if use_model:
|
166 |
+
# Display results with the model as a reference point
|
167 |
cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster(topic_df, error_type=error_type, n_comments=500)
|
168 |
else:
|
169 |
+
# Display results without a model
|
170 |
+
cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster_no_model(topic_df, n_comments=500)
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
|
172 |
+
cluster_comments = utils.get_cluster_comments(sampled_df,error_type=error_type, use_model=use_model) # New version of cluster comment table
|
173 |
|
174 |
results = {
|
175 |
"topic_df_ids": topic_df_ids,
|
176 |
"cluster_overview_plot_json": json.loads(cluster_overview_plot_json),
|
177 |
+
"cluster_comments": cluster_comments.to_json(orient="records"),
|
178 |
}
|
179 |
return json.dumps(results)
|
180 |
|
|
|
221 |
grp_ids = grp_df["worker_id"].tolist()
|
222 |
|
223 |
ratings_grp = utils.get_grp_model_labels(
|
|
|
224 |
n_label_per_bin=BIN_DISTRIB,
|
225 |
score_bins=SCORE_BINS,
|
226 |
grp_ids=grp_ids,
|
|
|
262 |
model_name_suggestion = f"my_model"
|
263 |
|
264 |
context = {
|
265 |
+
"personalized_models": utils.get_user_model_names(user),
|
266 |
"model_name_suggestion": model_name_suggestion,
|
267 |
"clusters_for_tuning": clusters_for_tuning_options,
|
268 |
}
|
|
|
270 |
|
271 |
########################################
|
272 |
# ROUTE: /GET_COMMENTS_TO_LABEL
|
273 |
+
if DEBUG:
|
274 |
+
BIN_DISTRIB = [1, 2, 4, 2, 1] # 10 comments
|
275 |
+
else:
|
276 |
+
BIN_DISTRIB = [2, 4, 8, 4, 2] # 20 comments
|
277 |
SCORE_BINS = [(0.0, 0.5), (0.5, 1.5), (1.5, 2.5), (2.5, 3.5), (3.5, 4.01)]
|
278 |
@app.route("/get_comments_to_label")
|
279 |
def get_comments_to_label():
|
280 |
n = int(request.args.get("n"))
|
281 |
# Fetch examples to label
|
282 |
to_label_ids = utils.create_example_sets(
|
|
|
283 |
n_label_per_bin=BIN_DISTRIB,
|
284 |
score_bins=SCORE_BINS,
|
285 |
keyword=None
|
|
|
296 |
|
297 |
########################################
|
298 |
# ROUTE: /GET_COMMENTS_TO_LABEL_TOPIC
|
|
|
299 |
@app.route("/get_comments_to_label_topic")
|
300 |
def get_comments_to_label_topic():
|
301 |
# Fetch examples to label
|
302 |
topic = request.args.get("topic")
|
303 |
to_label_ids = utils.create_example_sets(
|
|
|
|
|
304 |
n_label_per_bin=BIN_DISTRIB,
|
305 |
score_bins=SCORE_BINS,
|
306 |
keyword=None,
|
|
|
335 |
# Handle existing or new model cases
|
336 |
if mode == "view":
|
337 |
# Fetch prior model performance
|
338 |
+
mae, mse, rmse, avg_diff, ratings_prev = utils.fetch_existing_data(model_name, last_label_i)
|
|
|
|
|
|
|
339 |
|
340 |
elif mode == "train":
|
341 |
# Train model and cache predictions using new labels
|
|
|
425 |
reports = get_fixed_scaffold()
|
426 |
elif (scaffold_method == "personal" or scaffold_method == "personal_group" or scaffold_method == "personal_test"):
|
427 |
reports = get_personal_scaffold(model, topic_vis_method)
|
|
|
|
|
428 |
elif scaffold_method == "prompts":
|
429 |
reports = get_prompts_scaffold()
|
430 |
elif scaffold_method == "tutorial":
|
|
|
509 |
},
|
510 |
]
|
511 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
512 |
def get_topic_errors(df, topic_vis_method, threshold=2):
|
513 |
+
topics = df["topic"].unique().tolist()
|
514 |
topic_errors = {}
|
515 |
for topic in topics:
|
516 |
+
t_df = df[df["topic"] == topic]
|
517 |
y_true = t_df["pred"].to_numpy()
|
518 |
y_pred = t_df["rating"].to_numpy()
|
519 |
if topic_vis_method == "mae":
|
|
|
550 |
# Get topics with greatest amount of error
|
551 |
with open(f"./data/preds_dfs/{model}.pkl", "rb") as f:
|
552 |
preds_df = pickle.load(f)
|
553 |
+
system_preds_df = utils.get_system_preds_df()
|
554 |
+
preds_df_mod = preds_df.merge(system_preds_df, on="item_id", how="left", suffixes=('', '_sys'))
|
555 |
preds_df_mod = preds_df_mod[preds_df_mod["user_id"] == "A"].sort_values(by=["item_id"]).reset_index()
|
556 |
+
preds_df_mod = preds_df_mod[preds_df_mod["topic_id"] < n_topics]
|
557 |
|
558 |
if topic_vis_method == "median":
|
559 |
+
df = preds_df_mod.groupby(["topic", "user_id"]).median().reset_index()
|
560 |
elif topic_vis_method == "mean":
|
561 |
+
df = preds_df_mod.groupby(["topic", "user_id"]).mean().reset_index()
|
562 |
elif topic_vis_method == "fp_fn":
|
563 |
for error_type in ["fn_proportion", "fp_proportion"]:
|
564 |
topic_errors = get_topic_errors(preds_df_mod, error_type)
|
565 |
+
preds_df_mod[error_type] = [topic_errors[topic] for topic in preds_df_mod["topic"].tolist()]
|
566 |
+
df = preds_df_mod.groupby(["topic", "user_id"]).mean().reset_index()
|
567 |
else:
|
568 |
# Get error for each topic
|
569 |
topic_errors = get_topic_errors(preds_df_mod, topic_vis_method)
|
570 |
+
preds_df_mod[topic_vis_method] = [topic_errors[topic] for topic in preds_df_mod["topic"].tolist()]
|
571 |
+
df = preds_df_mod.groupby(["topic", "user_id"]).mean().reset_index()
|
572 |
|
573 |
# Get system error
|
574 |
+
df = df[(df["topic"] != "53_maiareficco_kallystas_dyisisitmanila_tractorsazi") & (df["topic"] != "79_idiot_dumb_stupid_dumber")]
|
575 |
|
576 |
if topic_vis_method == "median" or topic_vis_method == "mean":
|
577 |
df["error_magnitude"] = [utils.get_error_magnitude(sys, user, threshold) for sys, user in zip(df["rating"].tolist(), df["pred"].tolist())]
|
|
|
579 |
|
580 |
df_under = df[df["error_type"] == "System is under-sensitive"]
|
581 |
df_under = df_under.sort_values(by=["error_magnitude"], ascending=False).head(n) # surface largest errors first
|
582 |
+
report_under = [get_empty_report(row["topic"], row["error_type"]) for _, row in df_under.iterrows()]
|
583 |
|
584 |
df_over = df[df["error_type"] == "System is over-sensitive"]
|
585 |
df_over = df_over.sort_values(by=["error_magnitude"], ascending=False).head(n) # surface largest errors first
|
586 |
+
report_over = [get_empty_report(row["topic"], row["error_type"]) for _, row in df_over.iterrows()]
|
587 |
|
588 |
# Set up reports
|
|
|
589 |
reports = (report_under + report_over)
|
590 |
random.shuffle(reports)
|
591 |
elif topic_vis_method == "fp_fn":
|
592 |
df_under = df.sort_values(by=["fn_proportion"], ascending=False).head(n)
|
593 |
df_under = df_under[df_under["fn_proportion"] > 0]
|
594 |
+
report_under = [get_empty_report(row["topic"], "System is under-sensitive") for _, row in df_under.iterrows()]
|
595 |
|
596 |
df_over = df.sort_values(by=["fp_proportion"], ascending=False).head(n)
|
597 |
df_over = df_over[df_over["fp_proportion"] > 0]
|
598 |
+
report_over = [get_empty_report(row["topic"], "System is over-sensitive") for _, row in df_over.iterrows()]
|
599 |
|
600 |
reports = (report_under + report_over)
|
601 |
random.shuffle(reports)
|
602 |
else:
|
603 |
df = df.sort_values(by=[topic_vis_method], ascending=False).head(n * 2)
|
604 |
df["error_type"] = [utils.get_error_type_radio(sys, user, threshold) for sys, user in zip(df["rating"].tolist(), df["pred"].tolist())]
|
605 |
+
reports = [get_empty_report(row["topic"], row["error_type"]) for _, row in df.iterrows()]
|
606 |
|
607 |
return reports
|
608 |
|
|
|
673 |
n_examples = int(request.args.get("n_examples"))
|
674 |
|
675 |
# Get sample of examples
|
676 |
+
df = utils.get_explore_df(n_examples, threshold)
|
|
|
|
|
|
|
|
|
677 |
ex_json = df.to_json(orient="records")
|
678 |
|
679 |
results = {
|