rahulnair23's picture
test commit
0de1d17
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from typing import List
class bcolors:
    """ANSI escape sequences for coloring and styling terminal output.

    Prefix text with one of the codes and terminate it with ``ENDC`` to
    restore the terminal's default attributes, e.g.
    ``print(f"{bcolors.RED}error{bcolors.ENDC}")``.
    """

    # Bright foreground colors (SGR codes 91-95).
    PURPLE = "\033[95m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    WARNING = "\033[93m"  # bright yellow
    RED = "\033[91m"

    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Reset every color/attribute back to the terminal default.
    ENDC = "\033[0m"
def plot_ranks(r1: List, r2: List, r1_label: str, r2_label: str, output: str) -> plt.Figure:
    """Draw a two-column bump chart comparing two rankings and save it as a PNG.

    Each distinct item appearing in ``r1`` or ``r2`` becomes one line joining
    its 1-based position in ``r1`` (left column) to its 1-based position in
    ``r2`` (right column). An item missing from one list gets a NaN rank for
    that version. If an item occurs more than once in a list, its first
    occurrence is used.

    e.g.:
        fig = plot_ranks(true_ranking, ranking, "actual", "predicted", "output")

    Args:
        r1: First ranking, best item first; drawn as the left column.
        r2: Second ranking, best item first; drawn as the right column.
        r1_label: Name shown for the first ranking.
        r2_label: Name shown for the second ranking; must differ from
            ``r1_label`` (pandas ``pivot`` rejects duplicate entries).
        output: Output path without extension; ``".png"`` is appended.

    Returns:
        The matplotlib figure the chart was drawn on (still open).
    """

    def first_positions(ranking):
        # Map item -> index of its first occurrence, matching list.index(),
        # so each lookup below is O(1) instead of an O(n) scan.
        positions = {}
        for idx, item in enumerate(ranking):
            positions.setdefault(item, idx)
        return positions

    items = set(r1) | set(r2)
    records = []
    for label, positions in ((r1_label, first_positions(r1)), (r2_label, first_positions(r2))):
        for item in items:
            # np.nan + 1 stays NaN, marking the item as unranked in this version.
            records.append({"item": item, "version": label, "rank": positions.get(item, np.nan) + 1})

    df = (
        pd.DataFrame(records)
        .pivot(index="version", columns="item", values="rank")
        # pivot sorts the version labels alphabetically; restore argument order
        # so r1 is always the left column regardless of how the labels sort.
        .reindex([r1_label, r2_label])
    )

    fig = plt.figure(figsize=(5, 10))
    bumpchart(
        df,
        show_rank_axis=False,
        scatter=True,
        ax=fig.gca(),
        holes=False,
        line_args={"linewidth": 5, "alpha": 0.5},
        scatter_args={"s": 100, "alpha": 0.8},
    )
    # Save this specific figure rather than whatever pyplot considers current.
    fig.savefig(f"{output}.png", dpi=150, bbox_inches="tight")
    return fig
def bumpchart(
    df,
    show_rank_axis=True,
    rank_axis_distance=1.1,
    ax=None,
    scatter=False,
    holes=False,
    line_args=None,
    scatter_args=None,
    hole_args=None,
):
    """Draw a bump chart: one line per column of ``df`` tracing its rank across the index.

    Args:
        df: DataFrame whose index gives the x positions (e.g. versions) and
            whose columns are the items; values are expected to be 1-based
            ranks (the y range is fixed to 1..len(df.columns)).
        show_rank_axis: Add a third y-axis, offset to the right, showing the
            numeric rank ticks.
        rank_axis_distance: Position of that extra axis' spine, in axes
            coordinates (1.0 is the right edge).
        ax: Axes to draw on; defaults to ``plt.gca()``.
        scatter: Also draw a marker at every data point.
        holes: Overdraw every data point in the axes' face color.
        line_args: Extra keyword arguments for the line ``Axes.plot`` calls.
        scatter_args: Extra keyword arguments for the marker ``Axes.scatter``.
        hole_args: Extra keyword arguments for the hole ``Axes.scatter``.

    Returns:
        List of axes: ``[left, right]``, plus the far-right rank axis when
        ``show_rank_axis`` is true.
    """
    # None defaults avoid sharing one mutable dict across calls.
    line_args = {} if line_args is None else line_args
    scatter_args = {} if scatter_args is None else scatter_args
    hole_args = {} if hole_args is None else hole_args

    left_yaxis = plt.gca() if ax is None else ax
    # Right axis shares x with the left one and carries the final-order labels.
    right_yaxis = left_yaxis.twinx()
    axes = [left_yaxis, right_yaxis]
    if show_rank_axis:
        far_right_yaxis = left_yaxis.twinx()
        axes.append(far_right_yaxis)

    x = df.index.values
    # Defaults first so the caller can override them instead of hitting a
    # "multiple values for keyword argument" TypeError.
    line_kwargs = {"solid_capstyle": "round", **line_args}
    hole_kwargs = {"color": left_yaxis.get_facecolor(), **hole_args}

    for col in df.columns:
        y = df[col]
        # Invisible copies on the twin axes so they autoscale/categorize x
        # exactly like the left axis.
        for axis in axes[1:]:
            axis.plot(x, y, alpha=0)
        left_yaxis.plot(x, y, **line_kwargs)
        if scatter:
            left_yaxis.scatter(x, y, **scatter_args)
        if holes:
            left_yaxis.scatter(x, y, **hole_kwargs)

    n_lines = len(df.columns)
    y_ticks = list(range(1, n_lines + 1))
    for axis in axes:
        axis.set_yticks(y_ticks)
        # bottom > top inverts the axis so rank 1 is drawn at the top.
        axis.set_ylim((n_lines + 0.5, 0.5))

    # Label ticks with item names ordered by their rank in the first / last row.
    left_yaxis.set_yticklabels(df.iloc[0].sort_values().index)
    right_yaxis.set_yticklabels(df.iloc[-1].sort_values().index)

    # Push the rank axis outward so it doesn't overlap the right-hand labels.
    if show_rank_axis:
        far_right_yaxis.spines["right"].set_position(("axes", rank_axis_distance))
    return axes