Spaces:
Build error
Build error
import os | |
import re | |
import pandas as pd | |
import evaluate | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
# Load the Hugging Face ``evaluate`` metric objects once, at import time, so
# every scoring call reuses them (loaded in the order bleu, rouge, meteor,
# accuracy).
bleu, rouge, meteor, accuracy = (
    evaluate.load(metric_name)
    for metric_name in ("bleu", "rouge", "meteor", "accuracy")
)
def extract_answer(text, debug=False):
    """Strip chat-template scaffolding from a raw model generation.

    Three regex passes are applied in order:

    1. delete every span that ends in an ``assistant`` or ``[/INST]`` marker
       (plus the characters up to the next word boundary), i.e. the prompt;
    2. delete everything from the first ``<...>`` tag onward (e.g. a trailing
       ``<|eot_id|>`` or ``</s>``);
    3. delete everything up to and including ``end_header_id|>`` followed by
       a blank line (presumably left over from Llama-3 style headers).

    Args:
        text: Raw generated text. Non-string values (``None``, or NaN as read
            back by pandas from an empty CSV cell) and empty strings are
            returned unchanged.
        debug: If True, print the intermediate text after each pass.

    Returns:
        The cleaned answer string, or ``text`` itself when it is not a
        non-empty string.
    """
    # NaN is truthy, so a plain ``if text:`` guard would let it reach
    # re.sub(), which raises TypeError; check the type explicitly instead.
    if not isinstance(text, str) or not text:
        return text

    # Remove the begin and end tokens.
    # NOTE: re.sub replaces *all* matches, so this consumes everything up to
    # the last marker; an answer that itself contains "assistant" is
    # truncated as well.
    text = re.sub(
        r".*?(assistant|\[/INST\]).+?\b", "", text, flags=re.DOTALL | re.MULTILINE
    )
    if debug:
        print("--------\nstep 1:", text)
    # Drop the first tag and everything after it.
    text = re.sub(r"<.+?>.*", "", text, flags=re.DOTALL | re.MULTILINE)
    if debug:
        print("--------\nstep 2:", text)
    # Drop a leftover "...end_header_id|>\n\n" header remnant.
    text = re.sub(
        r".*?end_header_id\|>\n\n", "", text, flags=re.DOTALL | re.MULTILINE
    )
    if debug:
        print("--------\nstep 3:", text)
    return text
def calc_metrics(references, predictions, debug=False):
    """Score model predictions against reference translations.

    String predictions are first cleaned with ``extract_answer``. Then
    exact-match accuracy, METEOR, BLEU (``max_order=4``) and ROUGE are
    computed with the module-level ``evaluate`` metrics.

    Args:
        references: Iterable of reference strings.
        predictions: Iterable of raw model outputs, same length as
            ``references``. Non-string entries (e.g. ``None``, or NaN read
            back from an empty CSV cell) are scored as empty strings.
        debug: If True, also include ``correct_ids``, the indices of
            predictions that exactly match their reference.

    Returns:
        dict with keys ``accuracy``, ``meteor``, ``bleu_scores``,
        ``rouge_scores`` and, when ``debug`` is True, ``correct_ids``.

    Raises:
        AssertionError: If the two inputs differ in length.
    """
    # Materialise as lists: callers pass pandas Series, and the metric
    # backends are fed plain lists of strings.
    references = list(references)
    predictions = list(predictions)
    assert len(references) == len(
        predictions
    ), f"lengths are difference: {len(references)} != {len(predictions)}"

    # Non-string predictions would crash the regexes in extract_answer or
    # the metric tokenizers, so score them as empty answers.
    predictions = [
        extract_answer(text) if isinstance(text, str) else "" for text in predictions
    ]
    correct = [1 if ref == pred else 0 for ref, pred in zip(references, predictions)]
    # Named ``exact_match`` so it does not shadow the module-level
    # ``accuracy`` metric object.
    exact_match = sum(correct) / len(references)
    results = {"accuracy": exact_match}
    if debug:
        results["correct_ids"] = [i for i, c in enumerate(correct) if c == 1]
    results["meteor"] = meteor.compute(predictions=predictions, references=references)[
        "meteor"
    ]
    results["bleu_scores"] = bleu.compute(
        predictions=predictions, references=references, max_order=4
    )
    results["rouge_scores"] = rouge.compute(
        predictions=predictions, references=references
    )
    return results
def save_results(model_name, results_path, dataset, predictions, debug=False):
    """Store one model's predictions as a column of a CSV results table.

    If ``results_path`` does not exist yet, the table is seeded from
    ``dataset.to_pandas()`` with the ``text`` and ``prompt`` columns dropped,
    and any missing parent directories are created. Otherwise the existing
    CSV is loaded. The predictions are written to (or overwrite) the column
    ``model_name``, and the whole table is written back without the index.

    Args:
        model_name: Column name for this model's predictions.
        results_path: Path of the CSV file to create or update.
        dataset: Object with a ``to_pandas()`` method (presumably a Hugging
            Face ``datasets.Dataset``); only used when the file is missing.
        predictions: Sequence of predictions, one per row of the table.
        debug: If True, print the first row of the updated table.
    """
    if not os.path.exists(results_path):
        dir_path = os.path.dirname(results_path)
        # dirname() is "" for a bare filename, and os.makedirs("") raises
        # FileNotFoundError, so only create directories when there are some.
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        df = dataset.to_pandas()
        df.drop(columns=["text", "prompt"], inplace=True)
    else:
        df = pd.read_csv(results_path, on_bad_lines="warn")
    df[model_name] = predictions
    if debug:
        print(df.head(1))
    df.to_csv(results_path, index=False)
def get_metrics(df):
    """Compute per-model scores for a results table.

    The first two columns of ``df`` are skipped. Every remaining column holds
    one model's predictions and is scored against ``df["english"]`` with
    ``calc_metrics(..., debug=True)``; each model's full result is printed.

    Args:
        df: Results DataFrame containing an ``"english"`` reference column.

    Returns:
        DataFrame with one row per prediction column and columns ``model``
        (the column name after its last ``/``), ``accuracy``, ``meteor``,
        ``bleu_1`` (the overall ``bleu`` score), ``rouge_l`` and
        ``all_metrics`` (the full dict returned by ``calc_metrics``).
    """
    model_cols = df.columns[2:]
    metrics_df = pd.DataFrame({"model": [col.split("/")[-1] for col in model_cols]})

    # Names chosen so the module-level ``accuracy``/``meteor`` metric objects
    # are not shadowed inside this function.
    all_metrics = []
    for col in model_cols:
        metrics = calc_metrics(df["english"], df[col], debug=True)
        print(f"{col}: {metrics}")
        all_metrics.append(metrics)

    metrics_df["accuracy"] = [m["accuracy"] for m in all_metrics]
    metrics_df["meteor"] = [m["meteor"] for m in all_metrics]
    metrics_df["bleu_1"] = [m["bleu_scores"]["bleu"] for m in all_metrics]
    metrics_df["rouge_l"] = [m["rouge_scores"]["rougeL"] for m in all_metrics]
    metrics_df["all_metrics"] = all_metrics
    return metrics_df
def plot_metrics(metrics_df, figsize=(14, 5), ylim=(0, 0.44)):
    """Show a grouped, hatched bar chart of METEOR, BLEU-1 and ROUGE-L.

    One group of bars per metric and one bar per model. Each model gets its
    own hatch pattern (mirrored in the legend), and every non-zero bar is
    labelled with its height to two decimals.

    Args:
        metrics_df: Table with a ``model`` column plus ``meteor``, ``bleu_1``
            and ``rouge_l`` score columns.
        figsize: Matplotlib figure size.
        ylim: (bottom, top) limits of the score axis.
    """
    plt.figure(figsize=figsize)
    long_df = pd.melt(
        metrics_df, id_vars="model", value_vars=["meteor", "bleu_1", "rouge_l"]
    )
    ax = sns.barplot(x="variable", y="value", hue="model", data=long_df)

    # One hatch per distinct model, cycling if there are more models than
    # patterns.
    patterns = ["/", "\\", "|", "-", "+", "x", "o", "O", ".", "*", "//", "\\\\"]
    unique_models = metrics_df["model"].unique()
    hatch_by_model = {
        name: patterns[pos % len(patterns)] for pos, name in enumerate(unique_models)
    }

    # Each consecutive run of ``bars_per_model`` patches is mapped to one
    # entry of the melted "model" column.
    # NOTE(review): relies on seaborn emitting all bars of one hue level
    # before the next; confirm against the installed seaborn version.
    bars_per_model = len(long_df["variable"].unique())
    model_col = long_df["model"]
    for pos, patch in enumerate(ax.patches):
        patch.set_hatch(hatch_by_model[model_col.iloc[pos // bars_per_model]])

    # Copy the hatches onto the legend handles so the legend matches the bars.
    legend_handles = ax.get_legend_handles_labels()[0]
    for handle, name in zip(legend_handles, unique_models):
        handle.set_hatch(hatch_by_model[name])

    ax.set_xticklabels(["METEOR", "BLEU-1", "ROUGE-L"])

    # Label every non-zero bar with its value, 10 points above its top edge.
    for patch in ax.patches:
        height = patch.get_height()
        if height != 0:
            ax.annotate(
                f"{height:.2f}",
                (patch.get_x() + patch.get_width() / 2.0, height),
                ha="center",
                va="center",
                xytext=(0, 10),
                textcoords="offset points",
            )

    ax.set(ylim=ylim, ylabel="Scores", xlabel="Metrics")
    plt.legend(bbox_to_anchor=(0.5, -0.1), loc="upper center")
    plt.show()
def plot_times(perf_df, ylim=0.421):
    """Plot stacked train/eval time bars per model, optionally with METEOR.

    Training time (red) forms the bottom of each bar, with evaluation time
    (orange) stacked on top. Every non-zero segment is labelled with its
    value. When ``perf_df`` has a ``meteor`` column, it is drawn as a line on
    a secondary y-axis whose upper limit is set to ``ylim``.

    Args:
        perf_df: Table with ``model``, ``train-time(mins)`` and
            ``eval-time(mins)`` columns, and optionally ``meteor``.
        ylim: Upper limit of the METEOR axis; the lower limit is kept as
            matplotlib chose it.
    """
    fig, time_ax = plt.subplots(figsize=(12, 10))
    models = perf_df["model"]
    train_mins = perf_df["train-time(mins)"]

    time_ax.set_xlabel("Models")
    time_ax.set_ylabel("Time (mins)")
    time_ax.set_xticks(range(len(models)))
    time_ax.set_xticklabels(models, rotation=90)

    # Bottom segment: training time.
    time_ax.bar(models, train_mins, color="tab:red", label="train-time")
    # Top segment: evaluation time, stacked on the training time.
    time_ax.bar(
        models,
        perf_df["eval-time(mins)"],
        bottom=train_mins,
        color="orange",
        label="eval-time",
    )
    time_ax.tick_params(axis="y")
    time_ax.legend(loc="upper left")

    if "meteor" in perf_df.columns:
        score_ax = time_ax.twinx()
        meteor_color = "tab:blue"
        score_ax.set_ylabel("METEOR", color=meteor_color)
        score_ax.plot(
            models,
            perf_df["meteor"],
            color=meteor_color,
            marker="o",
            label="meteor",
        )
        score_ax.tick_params(axis="y", labelcolor=meteor_color)
        score_ax.legend(loc="upper right")
        score_ax.set_ylim(score_ax.get_ylim()[0], ylim)

    # Write each segment's value 10 points below the segment's top edge.
    for patch in time_ax.patches:
        segment = patch.get_height()
        if segment == 0:
            continue
        time_ax.annotate(
            f"{segment:.2f}",
            (patch.get_x() + patch.get_width() / 2.0, patch.get_y() + segment),
            ha="center",
            va="center",
            xytext=(0, -10),
            textcoords="offset points",
        )

    fig.tight_layout()
    plt.show()