Spaces:
Runtime error
Runtime error
Michelle Lam
Refreshes overall plot when model is edited. Uses user ID in preds_df. Exports reports and chart to json. Streamlines performance plot and description.
c55019b
""" | |
UTILS FILE | |
""" | |
import random | |
import json | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import pickle | |
import os | |
import mne | |
from surprise import Dataset, Reader, SVD, accuracy, KNNBasic, KNNWithMeans, KNNWithZScore | |
from surprise.model_selection import train_test_split | |
from sklearn.utils import resample | |
from sklearn.metrics import mean_absolute_error | |
from sklearn.metrics import mean_squared_error | |
from scipy import stats | |
import math | |
import altair as alt | |
import matplotlib.pyplot as plt | |
import time | |
from sentence_transformers import SentenceTransformer, util | |
import torch | |
from bertopic import BERTopic | |
######################################## | |
# PRE-LOADING | |
# Chart colors: the current user's marks vs. everybody else's
YOUR_COLOR = '#6CADFD'
OTHER_USERS_COLOR = '#ccc'

# Bin edges on the 0-4 rating scale, with one display label per bin
BINS = [0, 0.5, 1.5, 2.5, 3.5, 4]
BIN_LABELS = ['0: Not at all toxic', '1: Slightly toxic', '2: Moderately toxic', '3: Very toxic', '4: Extremely toxic']

# Default cut-off used by the helpers in this module to split "toxic" from "non-toxic"
TOXIC_THRESHOLD = 2.0

alt.renderers.enable('altair_saver', fmts=['vega-lite', 'png'])

# Data-loading
# NOTE(review): these pickles are assumed to be trusted, locally generated
# artifacts; unpickling untrusted data is unsafe.
module_dir = "./"
with open(os.path.join(module_dir, "data/input/ids_to_comments.pkl"), "rb") as f:
    ids_to_comments = pickle.load(f)
with open(os.path.join(module_dir, "data/input/comments_to_ids.pkl"), "rb") as f:
    comments_to_ids = pickle.load(f)
# Every input is resolved relative to module_dir so they all move together
system_preds_df = pd.read_pickle(os.path.join(module_dir, "data/input/system_preds_df.pkl"))
sys_eval_df = pd.read_pickle(os.path.join(module_dir, "data/input/split_data/sys_eval_df.pkl"))
train_df = pd.read_pickle(os.path.join(module_dir, "data/input/split_data/train_df.pkl"))
train_df_ids = train_df["item_id"].unique().tolist()
model_eval_df = pd.read_pickle(os.path.join(module_dir, "data/input/split_data/model_eval_df.pkl"))
ratings_df_full = pd.read_pickle(os.path.join(module_dir, "data/input/ratings_df_full.pkl"))
worker_info_df = pd.read_pickle(os.path.join(module_dir, "data/input/worker_info_df.pkl"))

# Two-way lookup between topic ids and topic names, paired row by row
# (positional pairing, so it does not depend on the dataframe's index labels)
topic_ids = system_preds_df.topic_id
topics = system_preds_df.topic
topic_ids_to_topics = dict(zip(topic_ids, topics))
topics_to_topic_ids = dict(zip(topics, topic_ids))
unique_topics_ids = sorted(system_preds_df.topic_id.unique())
# Names for topic ids 0 .. (number of distinct ids - 2)
# NOTE(review): presumably one extra id (e.g. an outlier topic such as -1) exists
# outside this range -- confirm against the data
unique_topics = [topic_ids_to_topics[topic_id] for topic_id in range(len(topic_ids_to_topics) - 1)]
def get_toxic_threshold():
    """Return the module-level ``TOXIC_THRESHOLD`` value."""
    return TOXIC_THRESHOLD
def get_user_model_names(user):
    """Return the sorted names of the models stored for ``user``.

    A model is any sub-directory of ``./data/output/<user>``. An empty list is
    returned when the user has no directory under ``./data/output``, or when
    the output directory itself has not been created yet.
    """
    output_dir = "./data/output"
    if not os.path.isdir(output_dir):
        # Nothing has been written yet, so no user can have models
        return []

    # Only names that are real sub-directories of the output dir count as users
    users = [name for name in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, name))]
    if user not in users:
        # User does not exist
        return []

    # Fetch trained model names for the user
    user_dir = f"./data/output/{user}"
    return sorted(name for name in os.listdir(user_dir) if os.path.isdir(os.path.join(user_dir, name)))
def get_unique_topics():
    """Return the module-level ``unique_topics`` list."""
    return unique_topics
def get_large_clusters(min_n):
    """Return topic names for clusters with at least ``min_n`` rows in ``system_preds_df``.

    The first qualifying topic id (in groupby order) is always left out.
    """
    cluster_sizes = system_preds_df.groupby(by=["topic_id"]).size().reset_index(name='counts')
    large_enough = cluster_sizes[cluster_sizes["counts"] >= min_n]
    # NOTE(review): dropping the first entry presumably skips an outlier topic
    # (e.g. id -1) -- confirm; if that topic is filtered out by min_n, a real
    # topic is dropped instead.
    kept_ids = large_enough["topic_id"].tolist()[1:]
    kept_ids.sort()
    return [topic_ids_to_topics[t_id] for t_id in kept_ids]
def get_ids_to_comments():
    """Return the module-level ``ids_to_comments`` mapping."""
    return ids_to_comments
def get_workers_in_group(sel_gender, sel_race, sel_relig, sel_pol, sel_lgbtq):
    """Filter ``worker_info_df`` down to the workers matching the selections.

    The string ``"null"`` disables the gender / religion / politics / LGBTQ
    filters. ``sel_race`` is iterated, and every element must be found in the
    worker's ``race`` string (assumes a list of strings -- TODO confirm).

    Returns a ``(filtered_dataframe, number_of_rows)`` tuple.
    """
    workers = worker_info_df.copy()

    # Plain equality filters, applied in the same order as before
    equality_filters = (
        ("gender", sel_gender),
        ("religion_important", sel_relig),
        ("political_affilation", sel_pol),
    )
    for column, selected in equality_filters:
        if selected != "null":
            workers = workers[workers[column] == selected]

    if sel_lgbtq != "null":
        # "LGBTQ+" covers two status values; any other selection means heterosexual
        if sel_lgbtq == "LGBTQ+":
            statuses = ["Homosexual", "Bisexual"]
        else:
            statuses = ["Heterosexual"]
        workers = workers[workers["lgbtq_status"].isin(statuses)]

    if sel_race != "":
        workers = workers.dropna(subset=['race'])
        for race in sel_race:
            # Filter to rows with the indicated race
            workers = workers[workers["race"].str.contains(race)]

    return workers, len(workers)
# Display name -> internal key for metrics and grouping fields
readable_to_internal = {
    "Mean Absolute Error (MAE)": "MAE",
    "Root Mean Squared Error (RMSE)": "RMSE",
    "Mean Squared Error (MSE)": "MSE",
    "Average rating difference": "avg_diff",
    "Topic": "topic",
    "Toxicity Category": "toxicity_category",
    "Toxicity Severity": "toxicity_severity",
}
# Reverse lookup: internal key -> display name
internal_to_readable = dict(zip(readable_to_internal.values(), readable_to_internal.keys()))
def get_system_preds_df():
    """Return the module-level ``system_preds_df`` dataframe."""
    return system_preds_df
######################################## | |
# Data storage helper functions | |
# Set up all directories for new user | |
def setup_user_dirs(cur_user):
    """Make sure ``./data/output/<cur_user>`` exists.

    Missing parent directories are created as well, and calling this for a
    directory that already exists is a no-op.
    """
    user_dir = f"./data/output/{cur_user}"
    # exist_ok avoids the check-then-create race of isdir() + mkdir()
    os.makedirs(user_dir, exist_ok=True)
def setup_model_dirs(cur_user, cur_model):
    """Make sure the model directory and its ``labels`` / ``perf`` subdirs exist.

    The model directory is ``./data/output/<cur_user>/<cur_model>``. Missing
    parents are created, and a subdirectory that is missing from an existing
    model directory is re-created.
    """
    model_dir = f"./data/output/{cur_user}/{cur_model}"
    # Creating each subdir with makedirs also creates model_dir itself
    for subdir in ("labels", "perf"):
        os.makedirs(os.path.join(model_dir, subdir), exist_ok=True)
def setup_user_model_dirs(cur_user, cur_model):
    """Set up the user's directory first, then the model's directories."""
    setup_user_dirs(cur_user)
    setup_model_dirs(cur_user, cur_model)
# Charts | |
def get_chart_file(cur_user, cur_model):
    """Return the path of the overall-vis chart JSON file for this user's model."""
    return os.path.join(f"./data/output/{cur_user}/{cur_model}", "chart_overall_vis.json")
# Labels | |
def get_label_dir(cur_user, cur_model):
    """Return the directory path holding the label files for this user's model."""
    label_dir = f"./data/output/{cur_user}/{cur_model}/labels"
    return label_dir
def get_n_label_files(cur_user, cur_model):
    """Count the regular files in the label directory of this user's model."""
    label_dir = get_label_dir(cur_user, cur_model)
    return sum(1 for name in os.listdir(label_dir) if os.path.isfile(os.path.join(label_dir, name)))
def get_label_file(cur_user, cur_model, label_i=None):
    """Return the path of label file number ``label_i``.

    When ``label_i`` is None, the current number of label files is used as the
    index, i.e. the path of the next file to be written.
    """
    # Get index to add on to end of list when none is given
    index = get_n_label_files(cur_user, cur_model) if label_i is None else label_i
    return os.path.join(get_label_dir(cur_user, cur_model), f"{index}.pkl")
# Performance | |
def get_perf_dir(cur_user, cur_model):
    """Return the directory path holding the performance files for this user's model."""
    perf_dir = f"./data/output/{cur_user}/{cur_model}/perf"
    return perf_dir
def get_n_perf_files(cur_user, cur_model):
    """Count the regular files in the performance directory of this user's model."""
    perf_dir = get_perf_dir(cur_user, cur_model)
    return sum(1 for name in os.listdir(perf_dir) if os.path.isfile(os.path.join(perf_dir, name)))
def get_perf_file(cur_user, cur_model, perf_i=None):
    """Return the path of performance file number ``perf_i``.

    When ``perf_i`` is None, the current number of performance files is used
    as the index, i.e. the path of the next file to be written.
    """
    # Get index to add on to end of list when none is given
    index = get_n_perf_files(cur_user, cur_model) if perf_i is None else perf_i
    return os.path.join(get_perf_dir(cur_user, cur_model), f"{index}.pkl")
# Predictions dataframe | |
def get_preds_file(cur_user, cur_model):
    """Return the path of the pickled predictions dataframe for this user's model."""
    return os.path.join(f"./data/output/{cur_user}/{cur_model}", "preds_df.pkl")
# Reports | |
def get_reports_file(cur_user, cur_model):
    """Return the path of the reports JSON file for this user's model."""
    reports_file = f"./data/output/{cur_user}/{cur_model}/reports.json"
    return reports_file
######################################## | |
# General utils | |
def get_metric_ind(metric):
    """Return the position of ``metric`` in a ``(MAE, MSE, RMSE, avg_diff)`` tuple.

    Raises:
        ValueError: if ``metric`` is not one of the four known metric keys.
    """
    metric_positions = {"MAE": 0, "MSE": 1, "RMSE": 2, "avg_diff": 3}
    try:
        return metric_positions[metric]
    except KeyError:
        # Fail with a clear message instead of an UnboundLocalError
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of {list(metric_positions)}"
        ) from None
def my_bootstrap(vals, n_boot, alpha):
    """Bootstrap the median of ``vals``.

    Draws ``n_boot`` resamples of ``len(vals)`` items with ``resample`` and
    records the median of each one.

    Returns ``(medians, (ci_low, ci_high))``, where the bounds are the
    ``(1 - alpha) / 2`` and ``alpha + (1 - alpha) / 2`` percentiles of the
    medians (so ``alpha`` acts as the interval's coverage, e.g. 0.95).
    """
    n_vals = len(vals)
    medians = [np.median(resample(vals, n_samples=n_vals)) for _ in range(n_boot)]

    tail = (1.0 - alpha) / 2.0
    lower = np.percentile(medians, tail * 100)
    upper = np.percentile(medians, (alpha + tail) * 100)
    return medians, (lower, upper)
######################################## | |
# GET_AUDIT utils | |
def plot_metric_histogram(metric, user_metric, other_metric_vals, n_bins=10):
    """Build an interactive histogram of other users' metric values with the user's value overlaid.

    - metric: internal metric key (used for the axis title and tooltips)
    - user_metric: the current user's value, drawn as a vertical rule
    - other_metric_vals: values to histogram into ``n_bins`` bins
    """
    counts, edges = np.histogram(other_metric_vals, bins=n_bins, density=False)
    hist_df = pd.DataFrame({
        "bin_min": edges[:-1],
        "bin_max": edges[1:],
        "bin_count": counts,
        # Same value on every row so the rule can be drawn from this frame
        "user_metric": [user_metric] * len(counts),
    })
    base = alt.Chart(hist_df)

    # Grey bars: distribution over the other users
    others_bar = base.mark_bar(color=OTHER_USERS_COLOR).encode(
        x=alt.X("bin_min", bin="binned", title=internal_to_readable[metric]),
        x2='bin_max',
        y=alt.Y("bin_count", title="Number of users"),
        tooltip=[
            alt.Tooltip('bin_min', title=f'{metric} bin min', format=".2f"),
            alt.Tooltip('bin_max', title=f'{metric} bin max', format=".2f"),
            alt.Tooltip('bin_count', title='Number of OTHER users', format=","),
        ],
    )

    # Colored rule: the current user's value
    user_rule = base.mark_rule(color=YOUR_COLOR).encode(
        x="mean(user_metric):Q",
        size=alt.value(2),
        tooltip=[
            alt.Tooltip('mean(user_metric)', title=f'{metric} with YOUR labels', format=".2f"),
        ],
    )

    return (others_bar + user_rule).interactive()
def show_overall_perf(cur_model, error_type, cur_user, threshold=TOXIC_THRESHOLD, topic_vis_method="median", use_cache=True):
    """Generate the summary plot across all topics for the user.

    - cur_model / cur_user: identify the saved predictions and the chart cache file
    - error_type: passed through to ``plot_overall_vis``
    - threshold: toxicity threshold passed through to ``plot_overall_vis``
    - topic_vis_method: "median" (default) or "mean"; how predictions are
      aggregated per (topic, user)
    - use_cache: when True and a chart file exists, it is returned as-is

    Returns a dict with the parsed chart spec under ``"topic_overview_plot_json"``.

    Raises:
        ValueError: if the chart must be regenerated and ``topic_vis_method``
            is neither "median" nor "mean".
    """
    chart_file = get_chart_file(cur_user, cur_model)
    if use_cache and os.path.isfile(chart_file):
        # Read from file if it exists
        with open(chart_file, "r") as f:
            topic_overview_plot_json = json.load(f)
    else:
        if topic_vis_method not in ("median", "mean"):
            raise ValueError(
                f"Unknown topic_vis_method {topic_vis_method!r}; expected 'median' or 'mean'"
            )

        # The (potentially large) predictions are only needed on a cache miss
        preds_file = get_preds_file(cur_user, cur_model)
        with open(preds_file, "rb") as f:
            preds_df = pickle.load(f)

        # numeric_only=True keeps the aggregation working on pandas >= 2.0,
        # where non-numeric columns are no longer dropped silently
        grouped = preds_df.groupby(["topic", "user_id"])
        if topic_vis_method == "median":  # Default
            preds_df_grp = grouped.median(numeric_only=True)
        else:
            preds_df_grp = grouped.mean(numeric_only=True)

        topic_overview_plot_json = plot_overall_vis(preds_df=preds_df_grp, n_topics=200, threshold=threshold, error_type=error_type, cur_user=cur_user, cur_model=cur_model)

        # Save to file (the chart is a JSON string, stored JSON-encoded)
        with open(chart_file, "w") as f:
            json.dump(topic_overview_plot_json, f)

    return {
        "topic_overview_plot_json": json.loads(topic_overview_plot_json),
    }
######################################## | |
# GET_LABELING utils | |
def create_example_sets(n_label_per_bin, score_bins, keyword=None, topic=None):
    """Sample comment ids to label from each system-rating bin.

    - n_label_per_bin: number of comments requested for each bin (indexed in
      parallel with ``score_bins``)
    - score_bins: ``(min_score, max_score)`` pairs; a comment falls in a bin
      when ``min_score <= rating < max_score``
    - keyword: if given, keep only comments matching it (``str.contains``)
    - topic: if given, keep only comments of that topic

    Only comments whose id is in ``train_df_ids`` are eligible. Each bin
    contributes at most as many ids as it has eligible comments.

    Returns the sampled comment ids as one flat list.
    """
    # Filtering below always builds new frames, so no defensive copy is needed
    df = system_preds_df
    if keyword is not None:
        df = df[df["comment"].str.contains(keyword)]
    if topic is not None:
        df = df[df["topic"] == topic]

    # Eligibility does not depend on the bin, so compute it once
    in_train = df["item_id"].isin(train_df_ids)

    ex_to_label = []
    for i, (min_score, max_score) in enumerate(score_bins):
        in_bin = (df["rating"] >= min_score) & (df["rating"] < max_score)
        comment_ids = df[in_bin & in_train].item_id.tolist()
        # Never ask for more samples than there are eligible comments
        cap = min(len(comment_ids), n_label_per_bin[i])
        ex_to_label.extend(np.random.choice(comment_ids, cap, replace=False))
    return ex_to_label
def get_grp_model_labels(n_label_per_bin, score_bins, grp_ids):
    """Build a ``{comment: rating}`` dict from a group's median training ratings.

    - n_label_per_bin: number of comments requested for each bin (indexed in
      parallel with ``score_bins``)
    - score_bins: ``(min_score, max_score)`` pairs applied to the *system*
      rating (``min_score <= rating < max_score``)
    - grp_ids: user ids whose rows of ``train_df`` form the group

    Comments are sampled per system-rating bin among those the group rated;
    the returned rating is the group's median rating for that comment.
    """
    # Median rating per comment over the group's training rows.
    # numeric_only=True keeps this working on pandas >= 2.0 when non-numeric
    # columns (e.g. string user ids) are present.
    train_df_grp = train_df[train_df["user_id"].isin(grp_ids)]
    train_df_grp_avg = train_df_grp.groupby(by=["item_id"]).median(numeric_only=True).reset_index()

    # Eligibility does not depend on the bin, so compute it once (read-only use)
    rated_by_group = system_preds_df["item_id"].isin(train_df_grp_avg["item_id"].tolist())

    ex_to_label = []  # IDs of comments to use for group model training
    for i, (min_score, max_score) in enumerate(score_bins):
        # get eligible comments to sample
        in_bin = (system_preds_df["rating"] >= min_score) & (system_preds_df["rating"] < max_score)
        comment_ids = system_preds_df[in_bin & rated_by_group].item_id.unique().tolist()
        # sample comments, capped at the number available
        cap = min(len(comment_ids), n_label_per_bin[i])
        ex_to_label.extend(np.random.choice(comment_ids, cap, replace=False))

    sampled = train_df_grp_avg[train_df_grp_avg["item_id"].isin(ex_to_label)]
    return {
        ids_to_comments[int(item_id)]: rating
        for item_id, rating in zip(sampled["item_id"], sampled["rating"])
    }
######################################## | |
# GET_PERSONALIZED_MODEL utils | |
def fetch_existing_data(user, model_name):
    """Load the latest cached metrics and labels of an existing model.

    Returns ``(mae, mse, rmse, avg_diff, ratings_prev)``; ``ratings_prev`` is
    None when the model has no label files.

    Raises:
        Exception: if the model has no performance files.
    """
    n_perf_files = get_n_perf_files(user, model_name)
    if n_perf_files <= 0:
        raise Exception(f"Model {model_name} does not exist")

    # Latest cached performance results
    with open(get_perf_file(user, model_name, n_perf_files - 1), "rb") as perf_f:
        mae, mse, rmse, avg_diff = pickle.load(perf_f)

    # Latest user-provided labels, if any
    n_label_files = get_n_label_files(user, model_name)
    ratings_prev = None
    if n_label_files > 0:
        with open(get_label_file(user, model_name, n_label_files - 1), "rb") as label_f:
            ratings_prev = pickle.load(label_f)

    return mae, mse, rmse, avg_diff, ratings_prev
def train_updated_model(model_name, ratings, user, top_n=None, topic=None, debug=False):
    """Train an updated model with the specified name, user, and ratings.

    Main function called by the server's `get_personalized_model` endpoint.
    Saves ratings, performance metrics, and pre-computed predictions to files,
    and regenerates the cached summary plot.

    - model_name: name of the model to train
    - ratings: dictionary of comments to ratings (-1 means "unsure"); this
      dictionary is not modified
    - user: user name
    - top_n: number of comments to train on (used when a set was held out for
      the original user study)
    - topic: topic to train on (used when tuning for a specific topic)
    - debug: when True, print the number of ratings used for training

    Returns ``(mae, mse, rmse, avg_diff, ratings)``, where ``ratings`` is the
    full set of labels that was saved for this model version.
    """
    use_top_n = (topic is None) and (top_n is not None)

    if not use_top_n:
        # For topic tuning, need to fetch old labels
        n_label_files = get_n_label_files(user, model_name)
        if n_label_files > 0:
            label_file = get_label_file(user, model_name, n_label_files - 1)  # Get last label file
            with open(label_file, "rb") as f:
                ratings_prev = pickle.load(f)
            # Combine old and new labels; a NEW label overrides the old one for
            # the same comment. A new dict is built so the caller's argument
            # is left untouched.
            ratings = {**ratings_prev, **ratings}

    labeled_df = format_labeled_data(ratings, worker_id=user)
    # Filter out rows with "unsure" (-1)
    labeled_df = labeled_df[labeled_df["rating"] != -1]
    if use_top_n:
        # Filter to top N for user study
        labeled_df = labeled_df.head(top_n)

    if debug:
        print("len ratings for training:", len(labeled_df))

    # Save this batch of labels
    label_file = get_label_file(user, model_name)
    with open(label_file, "wb") as f:
        pickle.dump(ratings, f)

    # Train model
    cur_model, _, _, _ = train_user_model(ratings_df=labeled_df)

    # Compute performance metrics
    mae, mse, rmse, avg_diff = users_perf(cur_model, worker_id=user)

    # Save performance metrics
    perf_file = get_perf_file(user, model_name)
    with open(perf_file, "wb") as f:
        pickle.dump((mae, mse, rmse, avg_diff), f)

    # Pre-compute predictions for full dataset
    cur_preds_df = get_preds_df(cur_model, [user], sys_eval_df=ratings_df_full)

    # Save pre-computed predictions
    preds_file = get_preds_file(user, model_name)
    with open(preds_file, "wb") as f:
        pickle.dump(cur_preds_df, f)

    # Replace cached summary plot if it exists
    show_overall_perf(cur_model=model_name, error_type="Both", cur_user=user, use_cache=False)

    return mae, mse, rmse, avg_diff, ratings
def format_labeled_data(ratings, worker_id):
    """Turn a ``{comment: rating}`` dict into a ratings dataframe.

    Each comment is translated to its id through ``comments_to_ids`` and each
    rating is cast to ``int``. The result has the columns
    ``user_id`` (always ``worker_id``), ``item_id`` and ``rating``.
    """
    rows = [
        [worker_id, comments_to_ids[comment], int(rating)]
        for comment, rating in ratings.items()
    ]
    return pd.DataFrame(rows, columns=["user_id", "item_id", "rating"])
def users_perf(model, worker_id, sys_eval_df=sys_eval_df):
    """Compare the model's predictions for ``worker_id`` with the system ratings.

    Predictions are requested for every unique ``item_id`` in ``sys_eval_df``
    and joined with ``system_preds_df``; items without a prediction are dropped.

    Returns ``(mae, mse, rmse, avg_diff)`` as computed by ``get_overall_perf``.
    """
    # Placeholder rows (rating 0) for every comment we want a prediction for
    sys_eval_comment_ids = sys_eval_df.item_id.unique().tolist()
    empty_ratings_df = pd.DataFrame(
        [[worker_id, c_id, 0] for c_id in sys_eval_comment_ids],
        columns=["user_id", "item_id", "rating"],
    )

    # Compute predictions for full dataset; test_size=1. puts every row in the test set
    reader = Reader(rating_scale=(0, 4))
    eval_set_data = Dataset.load_from_df(empty_ratings_df, reader)
    _, testset = train_test_split(eval_set_data, test_size=1.)
    predictions = model.test(testset)

    df = empty_ratings_df  # user_id, item_id, rating
    user_item_preds = get_predictions_by_user_and_item(predictions)
    # Direct dict lookups instead of a row-wise DataFrame.apply
    df["pred"] = [
        user_item_preds.get(key, np.nan)
        for key in zip(df["user_id"], df["item_id"])
    ]

    df = df.merge(system_preds_df, on="item_id", how="left", suffixes=('', '_sys'))
    df.dropna(subset=["pred"], inplace=True)
    df["rating"] = df.rating.astype("int32")

    return get_overall_perf(df, worker_id)  # mae, mse, rmse, avg_diff
def get_overall_perf(preds_df, user_id):
    """Compute error metrics between the system ratings and a user's predicted ratings.

    Only rows of ``preds_df`` with the given ``user_id`` are used. The user's
    predicted rating (``pred``) is treated as ground truth and the system
    rating (``rating_sys``) as the prediction.

    Returns ``(mae, mse, rmse, avg_diff)``, where ``avg_diff`` is the mean of
    ``pred - rating_sys``.
    """
    user_rows = preds_df[preds_df["user_id"] == user_id]
    y_pred = user_rows.rating_sys.to_numpy()  # system's prediction
    y_true = user_rows.pred.to_numpy()  # user's (predicted) ground truth

    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    # Take the root directly: the `squared=False` flag is deprecated in recent
    # scikit-learn releases, and this gives the same value on every version
    rmse = np.sqrt(mse)
    avg_diff = np.mean(y_true - y_pred)
    return mae, mse, rmse, avg_diff
def get_predictions_by_user_and_item(predictions):
    """Index predictions by ``(user id, item id)``.

    Each prediction is a 5-tuple whose first, second and fourth fields are the
    user id, the item id and the estimate. Returns ``{(uid, iid): est}``; for
    repeated pairs the last estimate wins.
    """
    return {(uid, iid): est for uid, iid, _true_r, est, _details in predictions}
def get_preds_df(model, user_ids, sys_eval_df=sys_eval_df, bins=BINS):
    """Pre-compute predictions for the provided model and users on the system-eval dataset.

    - model: trained model
    - user_ids: list of user IDs to compute predictions for
    - sys_eval_df: dataframe of system eval labels (pre-computed); one
      prediction is requested per user and unique ``item_id``
    - bins: bin edges used to assign ``prediction_bin`` from the prediction

    Returns the predictions joined with ``system_preds_df`` (system columns
    that clash get a ``_sys`` suffix), without rows lacking a prediction,
    sorted by ``item_id``. Timing information is printed along the way.
    """
    # Prep dataframe for all predictions we'd like to request
    start = time.time()
    sys_eval_comment_ids = sys_eval_df.item_id.unique().tolist()
    empty_ratings_rows = [
        [user_id, c_id, 0]
        for user_id in user_ids
        for c_id in sys_eval_comment_ids
    ]
    empty_ratings_df = pd.DataFrame(empty_ratings_rows, columns=["user_id", "item_id", "rating"])
    print("setup", time.time() - start)

    # Evaluate model to get predictions; test_size=1. puts every row in the test set
    start = time.time()
    reader = Reader(rating_scale=(0, 4))
    eval_set_data = Dataset.load_from_df(empty_ratings_df, reader)
    _, testset = train_test_split(eval_set_data, test_size=1.)
    predictions = model.test(testset)
    print("train_test_split", time.time() - start)

    # Update dataframe with predictions
    df = empty_ratings_df  # user_id, item_id, rating
    user_item_preds = get_predictions_by_user_and_item(predictions)
    # Direct dict lookups instead of a row-wise DataFrame.apply
    df["pred"] = [
        user_item_preds.get(key, np.nan)
        for key in zip(df["user_id"], df["item_id"])
    ]

    df = df.merge(system_preds_df, on="item_id", how="left", suffixes=('', '_sys'))
    df.dropna(subset=["pred"], inplace=True)
    df["rating"] = df.rating.astype("int32")

    # Get binned predictions (based on user prediction)
    df["prediction_bin"] = pd.cut(df["pred"], bins, labels=False)

    return df.sort_values(by=["item_id"])
def train_user_model(ratings_df, train_df=train_df, model_eval_df=model_eval_df, train_frac=0.75, model_type="SVD", sim_type=None, user_based=True):
    """Train the specified model type with the user's ratings added, and evaluate it.

    - ratings_df: dataframe of the user's ratings; it is shuffled and split
      into a training part and a model-eval part
    - train_df: dataframe of training labels the training part is added to
    - model_eval_df: dataframe of model eval labels (validation set) the
      remaining part is added to
    - train_frac: fraction of ratings to use for training
    - model_type, sim_type, user_based: passed through to ``train_model``

    Returns ``(model, perf, labeled_train, labeled_model_eval)``.
    """
    # Shuffle, then send the first train_frac of rows to training and the rest to model eval
    labeled = ratings_df.sample(frac=1)
    batch_size = math.floor(len(labeled) * train_frac)
    labeled_train = labeled[:batch_size]
    labeled_model_eval = labeled[batch_size:]

    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0
    train_df_ext = pd.concat([train_df, labeled_train])
    model_eval_df_ext = pd.concat([model_eval_df, labeled_model_eval])

    # Train model and show model eval set results
    model, perf = train_model(train_df_ext, model_eval_df_ext, model_type=model_type, sim_type=sim_type, user_based=user_based)

    return model, perf, labeled_train, labeled_model_eval
def train_model(train_df, model_eval_df, model_type="SVD", sim_type=None, user_based=True):
    """Train a model on the training labels and evaluate it on the model_eval labels.

    - train_df: dataframe of training labels
    - model_eval_df: dataframe of model eval labels (validation set)
    - model_type: type of model to train: "SVD", "KNNBasic", "KNNWithMeans"
      or "KNNWithZScore"
    - sim_type: similarity name for the KNN models (not used by SVD)
    - user_based: for the KNN models, compute similarity between users
      (True) or items (False)

    Returns ``(algo, [mae, mse, rmse, fcp])``; the metrics are also printed.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    knn_classes = {
        "KNNBasic": KNNBasic,
        "KNNWithMeans": KNNWithMeans,
        "KNNWithZScore": KNNWithZScore,
    }
    # Reject unknown model types up front, before any data is prepared
    if model_type != "SVD" and model_type not in knn_classes:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'SVD' or one of {list(knn_classes)}"
        )

    reader = Reader(rating_scale=(0, 4))
    train_data = Dataset.load_from_df(train_df, reader)
    model_eval_data = Dataset.load_from_df(model_eval_df, reader)
    train_set = train_data.build_full_trainset()
    # test_size=1. puts every model-eval row in the test set
    _, model_eval_set = train_test_split(model_eval_data, test_size=1.)

    if model_type == "SVD":
        algo = SVD()  # SVD doesn't have similarity metric
    else:
        sim_options = {
            "name": sim_type,
            "user_based": user_based,  # compute similarity between users or items
        }
        algo = knn_classes[model_type](sim_options=sim_options)

    algo.fit(train_set)
    predictions = algo.test(model_eval_set)

    rmse = accuracy.rmse(predictions)
    fcp = accuracy.fcp(predictions)
    mae = accuracy.mae(predictions)
    mse = accuracy.mse(predictions)
    print(f"MAE: {mae}, MSE: {mse}, RMSE: {rmse}, FCP: {fcp}")

    perf = [mae, mse, rmse, fcp]
    return algo, perf
def plot_train_perf_results(user, model_name, mae):
    """Plot the MAE of every saved version of a model and describe the given MAE.

    - user, model_name: identify the saved performance files to plot
    - mae: the MAE the status message is based on

    Returns ``(plot, mae_status)``: a layered chart (colored quality bands
    behind a line of MAE per version) and an HTML status string.
    """
    n_perf_files = get_n_perf_files(user, model_name)
    all_rows = []
    for version in range(n_perf_files):
        perf_file = get_perf_file(user, model_name, version)
        with open(perf_file, "rb") as f:
            # Distinct name: must not overwrite the `mae` argument used below
            version_mae, _, _, _ = pickle.load(f)
        all_rows.append([version, version_mae, "Your MAE"])
    df = pd.DataFrame(all_rows, columns=["version", "perf", "metric"])

    chart = alt.Chart(df).mark_line(point=True).encode(
        x="version:O",
        y="perf",
        color=alt.Color("metric", title="Performance metric"),
        tooltip=[
            alt.Tooltip('version:O', title='Version'),
            alt.Tooltip('metric:N', title="Metric"),
            alt.Tooltip('perf:Q', title="Metric Value", format=".3f"),
        ],
    ).properties(
        title=f"Performance over model versions: {model_name}",
        width=500,
    )

    # Manually set for now
    mae_good = 1.0
    mae_okay = 1.2
    plot_dim_width = 500
    domain_min = 0.0
    domain_max = 2.0

    # Background bands marking the "Good" / "Okay" / "Needs improvement" ranges
    bkgd = alt.Chart(pd.DataFrame({
        "start": [mae_okay, mae_good, domain_min],
        "stop": [domain_max, mae_okay, mae_good],
        "bkgd": ["Needs improvement", "Okay", "Good"],
    })).mark_rect(opacity=0.2).encode(
        y=alt.Y("start:Q", scale=alt.Scale(domain=[0, domain_max]), title=""),
        y2=alt.Y2("stop:Q", title="Performance (MAE)"),
        x=alt.value(0),
        x2=alt.value(plot_dim_width),
        color=alt.Color("bkgd:O", scale=alt.Scale(
            domain=["Needs improvement", "Okay", "Good"],
            range=["red", "yellow", "green"]),
            title="How good is your MAE?"
        )
    )

    plot = (bkgd + chart).properties(width=plot_dim_width).resolve_scale(color='independent')

    if mae < mae_good:
        mae_status = "Your MAE is in the <b>Good</b> range. Your model looks ready to go."
    elif mae < mae_okay:
        mae_status = "Your MAE is in the <b>Okay</b> range. Your model can be used, but you can provide additional labels to improve it."
    else:
        mae_status = "Your MAE is in the <b>Needs improvement</b> range. Your model may need additional labels to improve."
    return plot, mae_status
######################################## | |
# New visualizations | |
# Constants | |
# Bin edges for the overview plots: 0 up to 4 in steps of 0.05
VIS_BINS = np.round(np.arange(0, 4.01, 0.05), 3)
# One label per bin: the midpoint of its two edges
VIS_BINS_LABELS = [
    np.round(np.mean(edge_pair), 3)
    for edge_pair in zip(VIS_BINS, VIS_BINS[1:])
]
def get_key(sys, user, threshold):
    """Return the legend key comparing the system rating with the user's rating.

    When both ratings fall on the same side of ``threshold`` the system
    "agrees". Otherwise the key reports how far the *system* rating lies from
    the threshold, in steps of 0.5.
    """
    if sys <= threshold and user <= threshold:
        return "System agrees: Non-toxic"
    if sys > threshold and user > threshold:
        return "System agrees: Toxic"

    # Disagreement: bucket by the system rating's distance from the threshold
    gap = abs(sys - threshold)
    buckets = (
        (1.5, "System differs: Error > 1.5"),
        (1.0, "System differs: Error > 1.0"),
        (0.5, "System differs: Error > 0.5"),
    )
    for cutoff, label in buckets:
        if gap > cutoff:
            return label
    return "System differs: Error <=0.5"
def get_key_no_model(sys, threshold):
    """Return the legend key for a system rating alone (at or below the threshold is non-toxic)."""
    return "System says: Non-toxic" if sys <= threshold else "System says: Toxic"
def get_user_color(user, threshold):
    """Return white for a user rating at or below the threshold, grey otherwise."""
    white, grey = "#FFF", "#808080"
    return white if user <= threshold else grey
def get_system_color(sys, user, threshold):
    """Return the box color comparing the system rating with the user's rating.

    White / grey when both ratings fall on the same (non-toxic / toxic) side
    of ``threshold``; otherwise a shade of red that deepens with the *system*
    rating's distance from the threshold.
    """
    if sys <= threshold and user <= threshold:
        return "#FFF"  # white
    if sys > threshold and user > threshold:
        return "#808080"  # grey

    # Disagreement: the further the system rating is from the threshold, the redder
    gap = abs(sys - threshold)
    shades = (
        (1.5, "#d62728"),  # red
        (1.0, "#ff7a5c"),  # med red
        (0.5, "#ffa894"),  # light red
    )
    for cutoff, color in shades:
        if gap > cutoff:
            return color
    return "#ffd1c7"  # very light red
def get_error_type(sys, user, threshold):
    """Describe how the system's decision relates to the user's decision.

    A rating at or below ``threshold`` counts as non-toxic. Returns None when
    neither comparison holds for one of the ratings (e.g. NaN).
    """
    if sys <= threshold:
        if user <= threshold:
            return "No error (agree non-toxic)"
        if user > threshold:
            return "System may be under-sensitive"
    elif sys > threshold:
        if user > threshold:
            return "No error (agree toxic)"
        if user <= threshold:
            return "System may be over-sensitive"
    return None
def get_error_type_radio(sys, user, threshold):
    """Return the radio-button option matching this pair of ratings.

    Agreement (both ratings on the same side of ``threshold``) maps to the
    "Show errors and non-errors" option. Returns None when neither comparison
    holds for one of the ratings (e.g. NaN).
    """
    if sys <= threshold:
        if user <= threshold:
            return "Show errors and non-errors"
        if user > threshold:
            return "System is under-sensitive"
    elif sys > threshold:
        if user > threshold:
            return "Show errors and non-errors"
        if user <= threshold:
            return "System is over-sensitive"
    return None
def get_error_magnitude(sys, user, threshold):
    """Return the absolute rating gap when the two decisions differ, else 0.

    Returns None when neither comparison holds for one of the ratings (e.g. NaN).
    """
    sys_low, sys_high = sys <= threshold, sys > threshold
    user_low, user_high = user <= threshold, user > threshold

    if (sys_low and user_low) or (sys_high and user_high):
        return 0  # no classification error
    if (sys_low and user_high) or (sys_high and user_low):
        return abs(sys - user)
    return None
def get_error_size(sys, user, threshold):
    """Return the signed gap ``sys - user`` when the two decisions differ, else 0.

    Returns None when neither comparison holds for one of the ratings (e.g. NaN).
    """
    sys_low, sys_high = sys <= threshold, sys > threshold
    user_low, user_high = user <= threshold, user > threshold

    if (sys_low and user_low) or (sys_high and user_high):
        return 0  # no classification error
    if (sys_low and user_high) or (sys_high and user_low):
        return sys - user
    return None
def get_decision(rating, threshold):
    """Return "Non-toxic" for a rating at or below the threshold, "Toxic" otherwise."""
    return "Non-toxic" if rating <= threshold else "Toxic"
def get_category(row, threshold=0.3):
    """Return the comma-separated toxicity categories whose fraction exceeds ``threshold``.

    ``row`` must provide the five ``is_*_frac`` keys below. Categories appear
    in a fixed order; the result is an empty string when none qualifies.
    """
    k_to_category = {
        "is_profane_frac": "Profanity",
        "is_threat_frac": "Threat",
        "is_identity_attack_frac": "Identity Attack",
        "is_insult_frac": "Insult",
        "is_sexual_harassment_frac": "Sexual Harassment",
    }
    # Joining an empty selection already yields ""
    return ", ".join(
        category
        for k, category in k_to_category.items()
        if row[k] > threshold
    )
def get_comment_url(row):
    """Return the in-page link fragment for the comment in ``row`` (uses ``row['item_id']``)."""
    comment_url = f"#{row['item_id']}/#comment"
    return comment_url
def get_topic_url(row):
    """Return the in-page link fragment for the topic in ``row`` (uses ``row['topic']``)."""
    topic_url = f"#{row['topic']}/#topic"
    return topic_url
def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
    """Plot the overall results histogram (each block is a topic).

    - preds_df: predictions whose index is turned into columns; must provide
      ``topic_id``, ``topic``, ``user_id``, ``item_id``, ``pred`` and ``sys_col``
    - error_type: "System is under-sensitive" / "System is over-sensitive"
      restrict the chart to predictions above / at-or-below the threshold;
      any other value shows everything
    - cur_user: only this user's rows are plotted
    - cur_model: not used by this function
    - n_topics: if given, keep only rows with ``topic_id < n_topics``
    - bins: bin edges for the predicted rating on the x axis
    - threshold: toxicity threshold
    - sys_col: column holding the system rating

    Returns the layered chart serialized as a JSON string.
    """
    topic_df = preds_df.copy().reset_index()
    if n_topics is not None:
        topic_df = topic_df[topic_df["topic_id"] < n_topics]

    topic_df["vis_pred_bin"] = pd.cut(topic_df["pred"], bins, labels=VIS_BINS_LABELS)
    topic_df = topic_df[topic_df["user_id"] == cur_user].sort_values(by=["item_id"]).reset_index()

    # Derived columns used by the chart encodings
    sys_ratings = topic_df[sys_col].tolist()
    user_preds = topic_df["pred"].tolist()
    topic_df["system_label"] = [("toxic" if rating > threshold else "non-toxic") for rating in sys_ratings]
    topic_df["threshold"] = [threshold] * len(sys_ratings)
    topic_df["key"] = [get_key(sys_rating, user_pred, threshold) for sys_rating, user_pred in zip(sys_ratings, user_preds)]
    topic_df["url"] = topic_df.apply(get_topic_url, axis=1)

    # Plot sizing: marks are sized from the tallest column and capped at 75;
    # when capped, the plot height is derived from the tallest column instead
    domain_min = 0
    domain_max = 4
    plot_dim_height = 500
    plot_dim_width = 750
    max_items = np.max(topic_df["vis_pred_bin"].value_counts().tolist())
    mark_size = np.round(plot_dim_height / max_items) * 8
    if mark_size > 75:
        mark_size = 75
        plot_dim_height = 13 * max_items

    # Main chart: one square per topic, stacked within each prediction bin in
    # order of system rating
    key_scale = alt.Scale(
        domain=["System agrees: Non-toxic", "System agrees: Toxic", "System differs: Error > 1.5", "System differs: Error > 1.0", "System differs: Error > 0.5", "System differs: Error <=0.5"],
        range=["white", "#cbcbcb", "red", "#ff7a5c", "#ffa894", "#ffd1c7"],
    )
    squares = alt.Chart(topic_df).mark_square(opacity=0.8, size=mark_size, stroke="grey", strokeWidth=0.5)
    stacked = squares.transform_window(
        groupby=['vis_pred_bin'],
        sort=[{'field': sys_col}],
        id='row_number()',
        ignorePeers=True,
    )
    chart = stacked.encode(
        x=alt.X('vis_pred_bin:Q', title="Our prediction of your rating", scale=alt.Scale(domain=(domain_min, domain_max))),
        y=alt.Y('id:O', title="Topics (ordered by System toxicity rating)", axis=alt.Axis(values=list(range(0, max_items, 5))), sort='descending'),
        color=alt.Color("key:O", scale=key_scale, title="System rating (box color)"),
        href="url:N",
        tooltip=[
            alt.Tooltip("topic:N", title="Topic"),
            alt.Tooltip("system_label:N", title="System label"),
            alt.Tooltip(f"{sys_col}:Q", title="System rating", format=".2f"),
            alt.Tooltip("pred:Q", title="Your rating", format=".2f"),
        ],
    )

    # Filter to specified error type
    if error_type == "System is under-sensitive":
        # FN: system rates non-toxic, but user rates toxic
        chart = chart.transform_filter(alt.FieldGTPredicate(field="pred", gt=threshold))
    elif error_type == "System is over-sensitive":
        # FP: system rates toxic, but user rates non-toxic
        chart = chart.transform_filter(alt.FieldLTEPredicate(field="pred", lte=threshold))

    # Threshold line
    threshold_label = f"Threshold = {threshold}"
    rule_df = pd.DataFrame({
        "threshold": [threshold],
        "System threshold": [threshold_label],
    })
    rule = alt.Chart(rule_df).mark_rule().encode(
        x=alt.X("mean(threshold):Q", scale=alt.Scale(domain=(domain_min, domain_max)), title=""),
        color=alt.Color("System threshold:N", scale=alt.Scale(domain=[threshold_label], range=["grey"])),
        size=alt.value(2),
    )

    # Plot region annotations, centered in the non-toxic and toxic halves
    nontoxic_x = (domain_min + threshold) / 2.
    toxic_x = (domain_max + threshold) / 2.
    annotation_df = pd.DataFrame({
        "annotation_text": ["Non-toxic", "Toxic"],
        "x": [nontoxic_x, toxic_x],
        "y": [max_items, max_items],
    })
    annotation = alt.Chart(annotation_df).mark_text(
        align="center",
        baseline="middle",
        fontSize=16,
        dy=10,
        color="grey",
    ).encode(
        x=alt.X("x", title=""),
        y=alt.Y("y", title="", axis=None),
        text="annotation_text",
    )

    # Plot region background colors
    bkgd_df = pd.DataFrame({
        "start": [domain_min, threshold],
        "stop": [threshold, domain_max],
        "bkgd": ["Non-toxic (L side)", "Toxic (R side)"],
    })
    bkgd = alt.Chart(bkgd_df).mark_rect(opacity=1.0, stroke="grey", strokeWidth=0.25).encode(
        x=alt.X("start:Q", scale=alt.Scale(domain=[domain_min, domain_max])),
        x2=alt.X2("stop:Q"),
        y=alt.value(0),
        y2=alt.value(plot_dim_height),
        color=alt.Color(
            "bkgd:O",
            scale=alt.Scale(
                domain=["Non-toxic (L side)", "Toxic (R side)"],
                range=["white", "#cbcbcb"],
            ),
            title="Your rating (background color)",
        ),
    )

    layered = (bkgd + annotation + chart + rule).properties(height=(plot_dim_height), width=plot_dim_width)
    return layered.resolve_scale(color='independent').to_json()
# Plots cluster results histogram (each block is a comment), but *without* a model
# as a point of reference (in contrast to plot_overall_vis_cluster)
def plot_overall_vis_cluster_no_model(cur_user, preds_df, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
    """Build the per-comment histogram of system ratings for one user's rows.

    Every comment is drawn as a square stacked inside the bin of its system
    rating (``sys_col``). The squares are layered over a threshold rule,
    "Non-toxic"/"Toxic" text annotations, and a two-tone background.

    Args:
        cur_user: Value matched against the ``user_id`` column of ``preds_df``.
        preds_df: DataFrame with at least ``user_id`` and ``sys_col`` columns;
            ``comment`` is referenced by the tooltip. It is copied, not mutated.
        n_comments: If not None, randomly sample at most this many rows.
        bins: Bin edges passed to ``pd.cut`` (labels come from VIS_BINS_LABELS).
        threshold: Rating above which a comment is labeled "toxic".
        sys_col: Name of the column holding the system rating.

    Returns:
        A tuple ``(final_plot, df)``: the layered chart serialized with
        ``to_json()`` and the filtered/augmented DataFrame it was built from.
    """
    df = preds_df.copy().reset_index()
    df["vis_pred_bin"] = pd.cut(df[sys_col], bins, labels=VIS_BINS_LABELS)
    # drop=True (as in plot_overall_vis_cluster): the index was already turned
    # into a column above, so a second non-dropping reset would only add
    # another leftover index column to the returned frame.
    df = df[df["user_id"] == cur_user].sort_values(by=[sys_col]).reset_index(drop=True)
    df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df[sys_col].tolist()]
    df["key"] = [get_key_no_model(sys, threshold) for sys in df[sys_col].tolist()]
    df["category"] = df.apply(lambda row: get_category(row), axis=1)
    df["url"] = df.apply(lambda row: get_comment_url(row), axis=1)

    if n_comments is not None:
        # Never request more rows than are available
        n_to_sample = np.min([n_comments, len(df)])
        df = df.sample(n=n_to_sample)

    # Plot sizing
    domain_min = 0
    domain_max = 4
    plot_dim_height = 500
    plot_dim_width = 750
    # Tallest column of the histogram (largest number of comments in one bin)
    # NOTE(review): if the user has no rows this may be 0 and the division
    # below degenerates -- confirm callers never pass such input.
    max_items = np.max(df["vis_pred_bin"].value_counts().tolist())
    mark_size = np.round(plot_dim_height / max_items) * 8
    if mark_size > 75:
        # Cap the square size and shrink the plot height to fit the tallest bin
        mark_size = 75
        plot_dim_height = 13 * max_items

    # Main chart: row_number() within each bin gives each comment its y slot
    chart = alt.Chart(df).mark_square(opacity=0.8, size=mark_size, stroke="grey", strokeWidth=0.25).transform_window(
        groupby=['vis_pred_bin'],
        sort=[{'field': sys_col}],
        id='row_number()',
        ignorePeers=True
    ).encode(
        x=alt.X('vis_pred_bin:Q', title="System toxicity rating", scale=alt.Scale(domain=(domain_min, domain_max))),
        y=alt.Y('id:O', title="Comments (ordered by System toxicity rating)", axis=alt.Axis(values=list(range(0, max_items, 5))), sort='descending'),
        color=alt.Color("key:O", scale=alt.Scale(
            domain=["System says: Non-toxic", "System says: Toxic"],
            range=["white", "#cbcbcb"]),
            title="System rating",
            legend=None,
        ),
        href="url:N",
        tooltip=[
            alt.Tooltip("comment:N", title="comment"),
            alt.Tooltip(f"{sys_col}:Q", title="System rating", format=".2f"),
        ]
    )

    # Threshold line
    rule = alt.Chart(pd.DataFrame({
        "threshold": [threshold],
    })).mark_rule(color='grey').encode(
        x=alt.X("mean(threshold):Q", scale=alt.Scale(domain=[domain_min, domain_max]), title=""),
        size=alt.value(2),
    )

    # Plot region annotations, centered in the regions left/right of the threshold
    nontoxic_x = (domain_min + threshold) / 2.
    toxic_x = (domain_max + threshold) / 2.
    annotation = alt.Chart(pd.DataFrame({
        "annotation_text": ["Non-toxic", "Toxic"],
        "x": [nontoxic_x, toxic_x],
        "y": [max_items, max_items],
    })).mark_text(
        align="center",
        baseline="middle",
        fontSize=16,
        dy=10,
        color="grey"
    ).encode(
        x=alt.X("x", title=""),
        y=alt.Y("y", title="", axis=None),
        text="annotation_text"
    )

    # Plot region background colors
    bkgd = alt.Chart(pd.DataFrame({
        "start": [domain_min, threshold],
        "stop": [threshold, domain_max],
        "bkgd": ["Non-toxic", "Toxic"],
    })).mark_rect(opacity=1.0, stroke="grey", strokeWidth=0.25).encode(
        x=alt.X("start:Q", scale=alt.Scale(domain=[domain_min, domain_max])),
        x2=alt.X2("stop:Q"),
        y=alt.value(0),
        y2=alt.value(plot_dim_height),
        color=alt.Color("bkgd:O", scale=alt.Scale(
            domain=["Non-toxic", "Toxic"],
            range=["white", "#cbcbcb"]),
            title="System rating"
        )
    )

    # Layer order matters: background first, comment squares and rule on top
    final_plot = (bkgd + annotation + chart + rule).properties(height=(plot_dim_height), width=plot_dim_width).resolve_scale(color='independent').to_json()
    return final_plot, df
# Plots cluster results histogram (each block is a comment) *with* a model as a point of reference
def plot_overall_vis_cluster(cur_user, preds_df, error_type, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, sys_col="rating_sys"):
    """Build the per-comment histogram that compares system ratings to predictions.

    Comments are binned on the ``pred`` column (x axis) and colored by the key
    returned from ``get_key`` for each (system rating, prediction) pair. The
    squares are layered over a threshold rule, region annotations, and a
    two-tone background.

    Args:
        cur_user: Value matched against the ``user_id`` column of ``preds_df``.
        preds_df: DataFrame with at least ``user_id``, ``pred`` and ``sys_col``
            columns; ``comment`` is referenced by the tooltip. It is copied,
            not mutated.
        error_type: "System is under-sensitive" keeps only marks with
            ``pred > threshold``; "System is over-sensitive" keeps only marks
            with ``pred <= threshold``; any other value applies no filter.
        n_comments: If not None, randomly sample at most this many rows.
        bins: Bin edges passed to ``pd.cut`` (labels come from VIS_BINS_LABELS).
        threshold: Rating above which a comment is labeled "toxic".
        sys_col: Name of the column holding the system rating.

    Returns:
        A tuple ``(final_plot, df)``: the layered chart serialized with
        ``to_json()`` and the filtered/augmented DataFrame it was built from.
    """
    user_df = preds_df.copy().reset_index()
    pred_bins, _bin_edges = pd.cut(user_df["pred"], bins, labels=VIS_BINS_LABELS, retbins=True)
    user_df["vis_pred_bin"] = pred_bins

    # Keep only this user's rows, ordered by the system rating
    user_df = user_df[user_df["user_id"] == cur_user]
    user_df = user_df.sort_values(by=[sys_col])
    user_df = user_df.reset_index(drop=True)

    system_ratings = user_df[sys_col].tolist()
    predicted_ratings = user_df["pred"].tolist()
    user_df["system_label"] = [
        "toxic" if rating > threshold else "non-toxic"
        for rating in system_ratings
    ]
    user_df["key"] = [
        get_key(system_rating, predicted_rating, threshold)
        for system_rating, predicted_rating in zip(system_ratings, predicted_ratings)
    ]
    user_df["category"] = user_df.apply(get_category, axis=1)
    user_df["url"] = user_df.apply(get_comment_url, axis=1)

    if n_comments is not None:
        sample_size = np.min([n_comments, len(user_df)])
        user_df = user_df.sample(n=sample_size)

    # Plot sizing
    domain_min, domain_max = 0, 4
    plot_dim_width = 750
    plot_dim_height = 500
    # Tallest column of the histogram (largest number of comments in one bin)
    max_items = np.max(user_df["vis_pred_bin"].value_counts().tolist())
    mark_size = np.round(plot_dim_height / max_items) * 8
    if mark_size > 75:
        # Cap the square size and shrink the plot height to fit the tallest bin
        mark_size, plot_dim_height = 75, 13 * max_items

    # Main chart
    x_encoding = alt.X(
        'vis_pred_bin:Q',
        title="Our prediction of your rating",
        scale=alt.Scale(domain=(domain_min, domain_max)),
    )
    y_encoding = alt.Y(
        'id:O',
        title="Comments (ordered by System toxicity rating)",
        axis=alt.Axis(values=list(range(0, max_items, 5))),
        sort='descending',
    )
    color_encoding = alt.Color(
        "key:O",
        scale=alt.Scale(
            domain=["System agrees: Non-toxic", "System agrees: Toxic", "System differs: Error > 1.5", "System differs: Error > 1.0", "System differs: Error > 0.5", "System differs: Error <=0.5"],
            range=["white", "#cbcbcb", "red", "#ff7a5c", "#ffa894", "#ffd1c7"],
        ),
        title="System rating (box color)",
    )
    tooltips = [
        alt.Tooltip("comment:N", title="comment"),
        alt.Tooltip(f"{sys_col}:Q", title="System rating", format=".2f"),
        alt.Tooltip("pred:Q", title="Your rating", format=".2f"),
        alt.Tooltip("category:N", title="Potential toxicity categories"),
    ]
    squares = alt.Chart(user_df).mark_square(
        opacity=0.8, size=mark_size, stroke="grey", strokeWidth=0.25
    )
    # row_number() within each bin gives each comment its y slot
    squares = squares.transform_window(
        groupby=['vis_pred_bin'],
        sort=[{'field': sys_col}],
        id='row_number()',
        ignorePeers=True,
    )
    chart = squares.encode(
        x=x_encoding,
        y=y_encoding,
        color=color_encoding,
        href="url:N",
        tooltip=tooltips,
    )

    # Filter to specified error type
    if error_type == "System is under-sensitive":
        # FN: system rates non-toxic, but user rates toxic
        error_predicate = alt.FieldGTPredicate(field="pred", gt=threshold)
    elif error_type == "System is over-sensitive":
        # FP: system rates toxic, but user rates non-toxic
        error_predicate = alt.FieldLTEPredicate(field="pred", lte=threshold)
    else:
        error_predicate = None
    if error_predicate is not None:
        chart = chart.transform_filter(error_predicate)

    # Threshold line
    threshold_df = pd.DataFrame({"threshold": [threshold]})
    rule = alt.Chart(threshold_df).mark_rule(color='grey').encode(
        x=alt.X("mean(threshold):Q", scale=alt.Scale(domain=[domain_min, domain_max]), title=""),
        size=alt.value(2),
    )

    # Plot region annotations, centered in the regions left/right of the threshold
    nontoxic_x = (domain_min + threshold) / 2.
    toxic_x = (domain_max + threshold) / 2.
    annotation_df = pd.DataFrame({
        "annotation_text": ["Non-toxic", "Toxic"],
        "x": [nontoxic_x, toxic_x],
        "y": [max_items, max_items],
    })
    annotation = alt.Chart(annotation_df).mark_text(
        align="center",
        baseline="middle",
        fontSize=16,
        dy=10,
        color="grey",
    ).encode(
        x=alt.X("x", title=""),
        y=alt.Y("y", title="", axis=None),
        text="annotation_text",
    )

    # Plot region background colors
    region_df = pd.DataFrame({
        "start": [domain_min, threshold],
        "stop": [threshold, domain_max],
        "bkgd": ["Non-toxic (L side)", "Toxic (R side)"],
    })
    region_color = alt.Color(
        "bkgd:O",
        scale=alt.Scale(
            domain=["Non-toxic (L side)", "Toxic (R side)"],
            range=["white", "#cbcbcb"],
        ),
        title="Your rating (background color)",
    )
    bkgd = alt.Chart(region_df).mark_rect(opacity=1.0, stroke="grey", strokeWidth=0.25).encode(
        x=alt.X("start:Q", scale=alt.Scale(domain=[domain_min, domain_max])),
        x2=alt.X2("stop:Q"),
        y=alt.value(0),
        y2=alt.value(plot_dim_height),
        color=region_color,
    )

    # Layer order matters: background first, comment squares and rule on top
    layered = bkgd + annotation + chart + rule
    layered = layered.properties(height=(plot_dim_height), width=plot_dim_width)
    final_plot = layered.resolve_scale(color='independent').to_json()
    return final_plot, user_df
def get_cluster_comments(df, error_type, threshold=TOXIC_THRESHOLD, sys_col="rating_sys", use_model=True):
    """Prepare the per-comment table shown for a cluster.

    Adds color, error-type and decision columns derived from the system rating
    (``sys_col``) and the ``pred`` column, sorts the rows, rounds numeric
    columns to two decimals, and optionally filters by error type.

    Args:
        df: DataFrame with ``pred``, ``sys_col``, ``item_id`` and ``category``
            columns. It is copied, so the caller's frame is left untouched.
        error_type: "System is under-sensitive", "System is over-sensitive" or
            "Both" to filter on the computed ``error_type`` column; any other
            value keeps all rows.
        threshold: Toxicity threshold forwarded to the helper functions.
        sys_col: Name of the column holding the system rating.
        use_model: If True, sort by ``error_amt`` descending; otherwise sort
            by the system rating ascending.

    Returns:
        A new DataFrame with the added columns, sorted and filtered.
    """
    # Work on a copy: the column assignments below would otherwise modify the
    # caller's DataFrame in place (and warn when it is a slice of another frame).
    df = df.copy()

    sys_ratings = df[sys_col].tolist()
    user_ratings = df["pred"].tolist()
    df["user_color"] = [get_user_color(user, threshold) for user in user_ratings]  # get cell colors
    df["system_color"] = [get_user_color(sys_rating, threshold) for sys_rating in sys_ratings]  # get cell colors
    df["error_color"] = [
        get_system_color(sys_rating, user, threshold)
        for sys_rating, user in zip(sys_ratings, user_ratings)
    ]  # get cell colors
    df["error_type"] = [
        get_error_type(sys_rating, user, threshold)
        for sys_rating, user in zip(sys_ratings, user_ratings)
    ]  # get error type in words
    # Distance of the system rating from the threshold.
    # NOTE(review): the original comment called this the "raw error", but it
    # does not involve the user rating -- confirm this is the intended measure.
    df["error_amt"] = [abs(sys_rating - threshold) for sys_rating in sys_ratings]
    df["judgment"] = ""  # template for "agree" or "disagree" buttons

    if use_model:
        df = df.sort_values(by=["error_amt"], ascending=False)  # surface largest errors first
    else:
        print("get_cluster_comments; not using model")
        df = df.sort_values(by=[sys_col], ascending=True)

    # Columns below are derived after sorting, so lists are re-read in the new order
    df["id"] = df["item_id"]
    df["toxicity_category"] = df["category"]
    df["user_rating"] = df["pred"]
    df["user_decision"] = [get_decision(rating, threshold) for rating in df["pred"].tolist()]
    df["system_rating"] = df[sys_col]
    df["system_decision"] = [get_decision(rating, threshold) for rating in df[sys_col].tolist()]
    df = df.round(decimals=2)

    # Filter to specified error type
    under_sensitive = "System may be under-sensitive"  # FN: system rates non-toxic, but user rates toxic
    over_sensitive = "System may be over-sensitive"  # FP: system rates toxic, but user rates non-toxic
    kept_error_types = {
        "System is under-sensitive": [under_sensitive],
        "System is over-sensitive": [over_sensitive],
        "Both": [under_sensitive, over_sensitive],
    }.get(error_type)
    if kept_error_types is not None:
        df = df[df["error_type"].isin(kept_error_types)]
    return df
# PERSONALIZED CLUSTERS utils | |
def get_disagreement_comments(preds_df, mode, n=10_000, threshold=TOXIC_THRESHOLD):
    """Return the comments where the system and the prediction disagree most.

    Args:
        preds_df: DataFrame with ``rating`` (system rating), ``pred`` and
            ``comment`` columns. It is copied, not mutated.
        mode: "under-sensitive" or "over-sensitive"; selects which error type
            to keep and the sort direction of the ``diff`` column.
        n: Maximum number of rows to return.
        threshold: Toxicity threshold forwarded to the helper functions.

    Returns:
        A tuple ``(comments, df)``: the ``comment`` column as a list and the
        filtered, sorted DataFrame (with added ``diff`` and ``error_type``
        columns), both limited to the first ``n`` rows.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values.
    """
    # Resolve the mode first so an unsupported value fails with a clear
    # message instead of an unbound local at the sort step.
    # asc = low to high; lowest = sys lower than user (under-sensitive)
    # desc = high to low; lowest = sys higher than user (over-sensitive)
    if mode == "under-sensitive":
        wanted_error_type, asc = "System may be under-sensitive", True
    elif mode == "over-sensitive":
        wanted_error_type, asc = "System may be over-sensitive", False
    else:
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'under-sensitive' or 'over-sensitive'"
        )

    # Get difference between user rating and system rating
    df = preds_df.copy()
    rating_pairs = list(zip(df["rating"].tolist(), df["pred"].tolist()))
    df["diff"] = [get_error_size(sys_rating, user, threshold) for sys_rating, user in rating_pairs]
    df["error_type"] = [get_error_type(sys_rating, user, threshold) for sys_rating, user in rating_pairs]

    df = df[df["error_type"] == wanted_error_type]
    df = df.sort_values(by=["diff"], ascending=asc)
    df = df.head(n)
    return df["comment"].tolist(), df
def get_explore_df(n_examples, threshold):
    """Sample rows of ``system_preds_df`` and attach system decision/color columns.

    Args:
        n_examples: Number of rows to sample. Values larger than the number of
            available rows are clamped so the sample does not raise.
        threshold: Toxicity threshold forwarded to ``get_decision`` and
            ``get_user_color``.

    Returns:
        The sampled DataFrame with added ``system_decision`` and
        ``system_color`` columns, both derived from the ``rating`` column.
    """
    # Clamp like the plotting helpers do: sampling without replacement raises
    # when asked for more rows than exist. None is passed through unchanged.
    if n_examples is None:
        n_to_sample = None
    else:
        n_to_sample = min(n_examples, len(system_preds_df))
    df = system_preds_df.sample(n=n_to_sample)

    ratings = df["rating"].tolist()
    df["system_decision"] = [get_decision(rating, threshold) for rating in ratings]
    df["system_color"] = [get_user_color(rating, threshold) for rating in ratings]  # get cell colors
    return df