hma / common /plot /plot_arch_ablation.py
LeroyWaa's picture
draft
246c106
raw
history blame
1.59 kB
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# Variants shown on the x-axis, in plotting order.
tasks = ["Add", "Concat", "Cross Attention", "Modulation"]

# One measurement per variant, laid out as a (4, 1) column so that every
# column of `values` corresponds to one group of bars. The raw numbers are
# exponentiated before plotting.
# NOTE(review): presumably the raw numbers are log-perplexities (the y-axis
# is labelled "Perplexity") -- confirm against the evaluation logs.
values = np.exp(np.array([6.35, 5.68, 5.26, 5.02]).reshape(-1, 1))

# Fill colours, indexed by column of `values`.
bar_colors = [
    '#1f78b4',
    '#ffffff',
    '#a6cee3',
    '#cab2d6',
    '#b3b3cc',
    '#33a02c',
]
# Draw the architecture-ablation bar chart: one bar per entry of `tasks`,
# with heights taken from `values`, and save it as a PNG.
fig, ax = plt.subplots(figsize=(5, 3))

# Bar width and the base x position of each entry in `tasks`. Each column of
# `values` is drawn as its own group, shifted right by one bar width.
bar_width = 0.4
x = np.arange(len(tasks))

for i in range(values.shape[1]):
    bars = ax.bar(
        x + i * bar_width,
        values[:, i],
        width=bar_width,
        color=bar_colors[i],
        edgecolor='black',
    )

# Annotate every bar with its height (one decimal place). This runs once,
# after all groups are drawn, so no container is labelled twice.
for container in ax.containers:
    ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.1f")

# Highlight the last bar of the most recently drawn group. set_color()
# changes both face and edge, so the black outline is restored afterwards.
bars[-1].set_color('#cab2d6')
bars[-1].set_edgecolor('black')

# Axis labels and tick labels.
ax.set_xlabel("Model", fontsize=14)
ax.set_ylabel("Perplexity", fontsize=14)
ax.set_xticks(x)
ax.tick_params(axis='x', rotation=15)
ax.set_xticklabels(tasks, fontsize=12)

# The y-axis deliberately does not start at zero; the extra headroom above
# the tallest bar leaves space for the value labels.
ax.set_ylim(values.min() - 10, values.max() + 50)

plt.tight_layout()

# savefig() does not create missing directories, so make sure the target
# folder exists before writing the image.
output_path = "output/arch_ablation.png"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, dpi=300)