import pandas as pd | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from typing import List | |
class bcolors:
    """ANSI escape sequences for styling terminal output.

    Wrap text as ``f"{bcolors.RED}text{bcolors.ENDC}"``; ``ENDC`` resets all
    attributes back to the terminal default.
    """

    # Text attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Bright foreground colours
    RED = "\033[91m"
    GREEN = "\033[92m"
    WARNING = "\033[93m"  # yellow
    BLUE = "\033[94m"
    PURPLE = "\033[95m"

    # Reset everything
    ENDC = "\033[0m"
def _rank_table(r1: List, r2: List, r1_label: str, r2_label: str) -> pd.DataFrame:
    """Return a version-by-item table of 1-based ranks for two rankings.

    Rows are ``r1_label`` then ``r2_label`` (in that order, regardless of how
    the labels sort). Columns are every item appearing in either list, in
    sorted order. An item missing from a ranking gets NaN in that row. If an
    item occurs more than once in a list, its first position is used.
    """
    items = set(r1).union(r2)
    records = []
    for label, ranking in ((r1_label, r1), (r2_label, r2)):
        # One pass per list instead of a list.index() scan per item.
        first_pos = {}
        for pos, item in enumerate(ranking, start=1):
            first_pos.setdefault(item, pos)
        records.extend(
            {"item": item, "version": label, "rank": first_pos.get(item, np.nan)}
            for item in items
        )
    table = pd.DataFrame(records).pivot(index="item", columns="version", values="rank").T
    # pivot() sorts the version labels alphabetically; restore r1-then-r2 so
    # the first ranking is always the first (left) x position in the chart.
    return table.loc[[r1_label, r2_label]]


def plot_ranks(r1: List, r2: List, r1_label: str, r2_label: str, output: str) -> plt.Figure:
    """Draw a bump chart comparing two rankings and save it as a PNG.

    e.g.:
        fig = plot_ranks(true_ranking, ranking, "actual", "predicted", "output")

    Args:
        r1: First ranking, best item first; drawn at the left x position.
        r2: Second ranking, best item first; drawn at the right x position.
        r1_label: Label for ``r1``. Must differ from ``r2_label``.
        r2_label: Label for ``r2``.
        output: Output path without extension; ``.png`` is appended.

    Returns:
        The matplotlib Figure that was drawn and saved. It is not closed, so
        the caller owns it.
    """
    df = _rank_table(r1, r2, r1_label, r2_label)
    fig = plt.figure(figsize=(5, 10))
    bumpchart(
        df,
        show_rank_axis=False,
        scatter=True,
        ax=fig.gca(),
        holes=False,
        line_args={"linewidth": 5, "alpha": 0.5},
        scatter_args={"s": 100, "alpha": 0.8},
    )
    fig.savefig(f"{output}.png", dpi=150, bbox_inches="tight")
    return fig
def bumpchart(
    df,
    show_rank_axis=True,
    rank_axis_distance=1.1,
    ax=None,
    scatter=False,
    holes=False,
    line_args=None,
    scatter_args=None,
    hole_args=None,
):
    """Draw a bump chart: one line per column of ``df`` following its rank.

    Args:
        df: Table of ranks. The index values are the x positions; each column
            is one tracked item; values are 1-based ranks (NaN = unranked).
            The y range assumes ranks lie in ``1..len(df.columns)``.
        show_rank_axis: Add a third y-axis, offset to the right, that shows
            the numeric ranks.
        rank_axis_distance: Position of that extra axis' spine, in axes
            coordinates.
        ax: Axes to draw on; defaults to ``plt.gca()``.
        scatter: Also draw a marker at every point.
        holes: Draw axes-background-coloured markers on top of the points.
        line_args: Extra kwargs for each line's ``plot`` call. May override
            the default ``solid_capstyle="round"``.
        scatter_args: Extra kwargs for the ``scatter`` markers.
        hole_args: Extra kwargs for the hole markers.

    Returns:
        list of Axes: ``[left, right]``, plus the far-right rank axis when
        ``show_rank_axis`` is true.
    """
    # None defaults avoid sharing one mutable dict across calls.
    line_args = {} if line_args is None else line_args
    scatter_args = {} if scatter_args is None else scatter_args
    hole_args = {} if hole_args is None else hole_args

    left_yaxis = plt.gca() if ax is None else ax
    # Twin axes share the x axis, so each can carry its own y tick labels.
    right_yaxis = left_yaxis.twinx()
    axes = [left_yaxis, right_yaxis]
    if show_rank_axis:
        far_right_yaxis = left_yaxis.twinx()
        axes.append(far_right_yaxis)

    x = df.index.values
    # Caller-supplied keys win over the default cap style (previously a
    # duplicate-keyword TypeError).
    line_kwargs = {"solid_capstyle": "round", **line_args}
    for col in df.columns:
        y = df[col]
        # Invisible copies on the twin axes keep their data ranges aligned
        # with the left axis.
        for axis in axes[1:]:
            axis.plot(x, y, alpha=0)
        left_yaxis.plot(x, y, **line_kwargs)
        if scatter:
            left_yaxis.scatter(x, y, **scatter_args)
        if holes:
            bg_color = left_yaxis.get_facecolor()
            left_yaxis.scatter(x, y, color=bg_color, **hole_args)

    n_lines = len(df.columns)
    y_ticks = list(range(1, n_lines + 1))
    for axis in axes:
        axis.set_yticks(y_ticks)

    # Label each side with the items at their actual rank in that row.
    # Unranked (NaN) items are dropped; previously they were sorted to the end
    # and attached to rank ticks they do not hold.
    for axis, row in ((left_yaxis, df.iloc[0]), (right_yaxis, df.iloc[-1])):
        ranked = row.dropna().sort_values()
        axis.set_yticks(ranked.to_numpy())
        axis.set_yticklabels(ranked.index)

    # Reversed limits put rank 1 at the top and keep all y axes aligned.
    for axis in axes:
        axis.set_ylim((n_lines + 0.5, 0.5))

    # Push the rank axis outward so it does not overlap the right axis.
    if show_rank_axis:
        far_right_yaxis.spines["right"].set_position(("axes", rank_axis_distance))
    return axes