File size: 2,072 Bytes
ff66cf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import matplotlib as mpl
mpl.use("Agg")
import argparse
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import IPython
# Global text styling shared by every figure this script produces.
font = dict(size=22)
matplotlib.rc("font", **font)
# NOTE(review): seaborn's set_context also writes rcParams["font.size"]
# (scaled by font_scale), so it likely overrides the size set just above —
# confirm which size is actually intended.
sns.set_context("paper", font_scale=2.0)
def mkdir_if_missing(dst_dir):
    """Create ``dst_dir`` (and any missing parent directories) if absent.

    A single ``os.makedirs(..., exist_ok=True)`` call replaces the separate
    "exists?" check followed by a create, which could race: another process
    creating the directory in between made ``makedirs`` raise.

    Raises:
        FileExistsError: if ``dst_dir`` exists but is not a directory.
    """
    os.makedirs(dst_dir, exist_ok=True)
def save_figure(name, title="", *, out_root="output/output_figures", max_name_len=30):
    """Finalize the current matplotlib figure, write it to disk, then clear it.

    The image is written to ``<out_root>/<name[:max_name_len]>/output.png``.
    The directory path is printed before saving.

    Args:
        name: identifier used as the output sub-directory name. Only its
            first ``max_name_len`` characters are used, so distinct long
            names sharing a prefix write to the same file.
        title: optional figure title passed to ``plt.title``; skipped when
            empty or None.
        out_root: parent directory for all figure sub-directories.
        max_name_len: number of leading characters of ``name`` to keep.
    """
    if title:
        plt.title(title)
    plt.tight_layout()
    # Build the output directory once instead of repeating the expression.
    out_dir = f"{out_root}/{name[:max_name_len]}"
    print(out_dir)
    mkdir_if_missing(out_dir)
    plt.savefig(os.path.join(out_dir, "output.png"))
    # Clear the figure so later plotting starts from a blank canvas.
    plt.clf()
def _load_eval_results(multirun_out):
    """Read and concatenate ``eval_results.csv`` from each listed run directory.

    Returns:
        A tuple ``(df, run_num)``: the row-wise concatenation of every CSV found
        (with a fresh 0..N-1 index) and the number of CSV files loaded.

    Raises:
        ValueError: if none of the listed runs has an ``eval_results.csv``.
    """
    dfs = []
    # Tolerate stray whitespace and empty entries such as "a, b," in the list.
    rundirs = sorted(part.strip() for part in multirun_out.split(","))
    for rundir in rundirs:
        if not rundir:
            continue
        statspath = os.path.join("output/output_stats", rundir, "eval_results.csv")
        if not os.path.exists(statspath):
            print("skip:", statspath)
            continue
        dfs.append(pd.read_csv(statspath))
    if not dfs:
        raise ValueError(f"no eval_results.csv found for runs: {multirun_out!r}")
    # The CSVs share column names; ignore_index avoids duplicate row labels.
    return pd.concat(dfs, ignore_index=True), len(dfs)


def _plot_success_bars(df):
    """Draw ``success`` per ``metric`` as grouped bars, one hue per ``model``.

    Bar height uses seaborn's default estimator. Error bars show one standard
    deviation. Each bar is labelled with its value to two decimals.
    """
    fig, ax = plt.subplots(figsize=(16, 8))
    sns.barplot(
        data=df,
        x="metric",
        y="success",
        hue="model",
        errorbar=("sd", 1),
        palette="deep",
        hue_order=["gpt3", "gpt3-finetuned", "gpt3.5", "gpt3.5-finetuned", "gpt4"],
        ax=ax,
    )
    for container in ax.containers:
        ax.bar_label(container, label_type="center", fontsize="x-large", fmt="%.2f")
    return fig, ax


def main(multirun_out, title):
    """Plot per-metric success rates gathered from one or more evaluation runs.

    The CSVs are expected to provide ``metric``, ``success`` and ``model``
    columns, which the bar plot consumes.

    Args:
        multirun_out: comma-separated run directory names, each resolved under
            ``output/output_stats``. Runs without an ``eval_results.csv`` are
            skipped with a message.
        title: base figure title. `` run: <count> `` is appended, where count
            is the number of result files actually loaded.

    Raises:
        ValueError: if ``multirun_out`` is empty/None or no result file exists.
    """
    if not multirun_out:
        raise ValueError("multirun_out must be a comma-separated list of run directories")
    df, run_num = _load_eval_results(multirun_out)
    # Show the merged data itself (previously printed the .iloc indexer object).
    print(df)
    title += f" run: {run_num} "
    _plot_success_bars(df)
    save_figure(f"{multirun_out}_{title}", title)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Plot per-metric success rates aggregated over evaluation runs."
    )
    # main() cannot work without a run list, so fail early with a usage message
    # instead of crashing later on None.split(",").
    parser.add_argument(
        "--multirun_out",
        type=str,
        required=True,
        help="comma-separated run directory names under output/output_stats",
    )
    parser.add_argument(
        "--title",
        type=str,
        default="",
        help="base figure title (the loaded run count is appended)",
    )
    args = parser.parse_args()
    main(args.multirun_out, args.title)
|