# tofu_leaderboard / plotter.py
# Last commit cf8c271 ("big df") by pratyushmaini; file size 3.26 kB.
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
# Ignore every warning category for the rest of the process: ``Warning`` is the
# base class of all warning categories, and simplefilter edits the global
# warnings filter list.
# NOTE(review): this also hides library warnings (e.g. pandas'
# SettingWithCopyWarning) that can point at real bugs; consider narrowing it.
warnings.simplefilter("ignore", category=Warning)
def custom_agg(x):
    """Collapse one row of metric scores into a single number.

    ``create_plots`` applies this row-wise to build its ``MAPO`` column.

    Args:
        x: 1-D array-like of scores (e.g. one DataFrame row).

    Returns:
        The harmonic mean of ``x`` as computed by ``scipy.stats.hmean``.
    """
    return stats.hmean(x)
def create_plots(big_df, selected_methods, *, plot_legend=False):
    """Score unlearning runs, save per-metric bar charts and print MAPO tables.

    Steps:
      1. Keep only rows whose ``Method`` is in ``selected_methods``.
      2. Score every row with ``MAPO``: the harmonic mean (``custom_agg``) of
         all metric columns, using ``1 - ROUGE-P Forget`` in place of the raw
         ROUGE-P Forget value.
      3. For every (Method, Model, Forget Rate) keep the row with the highest
         MAPO.
      4. For Llama-2-7B, save one bar chart per metric (plus MAPO) to
         ``barplots/<metric>-Llama-2-7B[legend].pdf`` and print the matching
         LaTeX includegraphics line.
      5. Print a LaTeX table of the best MAPO per Method x Forget Rate for
         Llama-2-7B and Phi.

    Args:
        big_df: DataFrame with the columns ``Method``, ``Model``,
            ``Forget Rate``, ``LR``, ``Epoch`` and ``Compute``; every other
            column is treated as a metric, and ``ROUGE-P Forget`` must be one
            of them. It is not modified.
        selected_methods: Method names to keep.
        plot_legend: Draw a legend on the bar charts. This also selects a
            smaller font and adds a ``legend`` suffix to the file names.

    Returns:
        None. Results are printed and the charts are written to ``barplots/``.
    """
    # Explicit copy: later column assignments must not touch a view of the input.
    df = big_df[big_df["Method"].isin(selected_methods)].copy()
    metrics = _metric_columns(df)
    print(metrics)

    df["MAPO"] = _mapo_scores(df, metrics)
    df["LR"] = df["LR"].astype(float)
    # Unique 0..n-1 labels so the idxmax labels below select exactly one row each.
    df.reset_index(drop=True, inplace=True)
    print(
        df[["Method", "Model", "Forget Rate", "LR", "Epoch", "ROUGE-P Forget", "MAPO"]]
        .round(2)
        .to_markdown()
    )

    # Best-scoring run (highest MAPO) for each method/model/forget-rate combination.
    result = df.loc[df.groupby(["Method", "Model", "Forget Rate"])["MAPO"].idxmax()]
    print(result[["Method", "Model", "Forget Rate", "LR", "Epoch", "MAPO"]].round(6).to_markdown())

    _plot_metric_bars(result, metrics + ["MAPO"], model="Llama-2-7B", plot_legend=plot_legend)
    _print_mapo_tables(result, models=("Llama-2-7B", "Phi"))


def _metric_columns(df):
    """Return the metric column names of ``df`` (all non-metadata columns), in column order."""
    non_metric = {"Method", "Model", "Forget Rate", "LR", "Epoch", "Compute"}
    return [col for col in df.columns if col not in non_metric]


def _mapo_scores(df, metrics):
    """Return the per-row MAPO score: harmonic mean of ``metrics`` with ROUGE-P Forget inverted."""
    # Select the metric columns by name rather than by trailing position, and
    # invert on a copy so df["ROUGE-P Forget"] keeps its raw values for plotting.
    scores = df[metrics].copy()
    # Presumably inverted because a lower forget-set ROUGE means better forgetting.
    scores["ROUGE-P Forget"] = 1 - scores["ROUGE-P Forget"]
    return scores.apply(custom_agg, axis=1)


def _plot_metric_bars(result, metrics, model, plot_legend):
    """Save one bar chart per metric for ``model`` and print its LaTeX include line."""
    fs = 18 if plot_legend else 22
    suffix = "legend" if plot_legend else ""
    sns.set_theme(style="whitegrid")
    plt.rcParams["font.family"] = "Times New Roman"
    # savefig does not create missing directories.
    os.makedirs("barplots", exist_ok=True)

    # Mask built from `result` itself, so no index alignment is involved.
    sub_df = result[result["Model"] == model]
    for metric in metrics:
        fig, ax = plt.subplots(figsize=(15, 5))
        sns.barplot(x="Method", y=metric, hue="Forget Rate", data=sub_df, ax=ax, legend=plot_legend)
        ax.set_ylabel(metric, fontsize=fs)
        ax.set_ylim(0.0, 1.0)
        ax.set_xlabel("", fontsize=fs)
        # Resize tick labels without freezing them to fixed text.
        ax.tick_params(axis="x", labelsize=fs)
        ax.tick_params(axis="y", labelsize=fs - 4)
        ax.spines[["right", "top"]].set_visible(False)
        if plot_legend:
            ax.legend(loc="upper left", bbox_to_anchor=(1.05, 1), title="Forget Rate (%)")
        ax.set_title(f"{metric} on {model}", fontsize=fs)
        fig.tight_layout()
        fig.savefig(f"barplots/{metric}-{model}{suffix}.pdf")
        print(f"\\includegraphics[width=\\textwidth]{{figures/barplots/{metric}-{model}{suffix}.pdf}}")
        plt.close(fig)


def _print_mapo_tables(result, models):
    """Print, per model, a LaTeX table of best MAPO with Methods as rows and Forget Rates as columns."""
    for model in models:
        sub_df = result.loc[result["Model"] == model, ["Method", "Forget Rate", "MAPO"]]
        table = sub_df.pivot(index="Method", columns="Forget Rate", values="MAPO")
        print(table.round(4).to_latex(index=True))