from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


class bcolors:
    """ANSI escape sequences for colouring and styling terminal output."""

    PURPLE = '\033[95m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    WARNING = '\033[93m'
    RED = '\033[91m'
    ENDC = '\033[0m'  # resets all attributes set by the codes above
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


def plot_ranks(
    r1: List, r2: List, r1_label: str, r2_label: str, output: str
) -> plt.Figure:
    """Draw a bump chart comparing two rankings and save it as a PNG.

    Each item's rank is its 1-based position in the respective list; an item
    missing from one of the lists gets a NaN rank on that side.

    e.g.:
        fig = plot_ranks(true_ranking, ranking, "actual", "predicted", "output")

    Args:
        r1: First ranking (best item first); drawn on the left side.
        r2: Second ranking (best item first); drawn on the right side.
        r1_label: Label of the first ranking (x tick on the left).
        r2_label: Label of the second ranking (x tick on the right).
        output: Output path without extension; ``.png`` is appended.

    Returns:
        The matplotlib figure the chart was drawn on.
    """
    items = list(set(r1 + r2))

    # First-occurrence position of every item, so that each lookup is O(1)
    # while matching what ``list.index`` would return.
    positions = {}
    for label, ranking in zip((r1_label, r2_label), (r1, r2)):
        first_seen = {}
        for pos, item in enumerate(ranking):
            first_seen.setdefault(item, pos)
        positions[label] = first_seen

    rows = [
        {
            "item": item,
            "version": label,
            # NaN + 1 stays NaN, marking "not present in this ranking".
            "rank": positions[label].get(item, np.nan) + 1,
        }
        for item in items
        for label in (r1_label, r2_label)
    ]

    df = pd.DataFrame(rows).pivot(index="item", columns="version", values="rank").T
    # pivot() sorts the version labels alphabetically; restore the caller's
    # order so r1 is always on the left and r2 on the right.
    df = df.reindex([r1_label, r2_label])

    fig = plt.figure(figsize=(5, 10))
    bumpchart(
        df,
        show_rank_axis=False,
        scatter=True,
        ax=fig.gca(),
        holes=False,
        line_args={"linewidth": 5, "alpha": 0.5},
        scatter_args={"s": 100, "alpha": 0.8},
    )
    plt.savefig(f"{output}.png", dpi=150, bbox_inches="tight")
    return fig


def bumpchart(
    df: pd.DataFrame,
    show_rank_axis: bool = True,
    rank_axis_distance: float = 1.1,
    ax=None,
    scatter: bool = False,
    holes: bool = False,
    line_args: Optional[dict] = None,
    scatter_args: Optional[dict] = None,
    hole_args: Optional[dict] = None,
):
    """Plot a bump chart: one line per column of ``df`` across its index.

    Args:
        df: Rank table; the index supplies the x positions and every column
            is one line whose values are the ranks at those positions.
        show_rank_axis: If true, add a third y axis to the right of the
            right-hand label axis.
        rank_axis_distance: Position of that third axis, in axes coordinates.
        ax: Axes to draw on; defaults to the current axes (``plt.gca()``).
        scatter: If true, draw a marker at every data point.
        holes: If true (and ``scatter`` is true), overlay background-coloured
            markers on the data points.
        line_args: Extra keyword arguments for the line ``plot`` call.
        scatter_args: Extra keyword arguments for the marker ``scatter`` call.
        hole_args: Extra keyword arguments for the hole ``scatter`` call.

    Returns:
        List of the axes used: ``[left, right]`` plus the far-right axis when
        ``show_rank_axis`` is true.
    """
    # None instead of ``{}`` defaults: avoids dicts shared between calls.
    line_args = {} if line_args is None else line_args
    scatter_args = {} if scatter_args is None else scatter_args
    hole_args = {} if hole_args is None else hole_args

    # Round caps by default, but let the caller override instead of raising
    # "multiple values for keyword argument".
    line_kwargs = {"solid_capstyle": "round", **line_args}

    left_yaxis = plt.gca() if ax is None else ax

    # Creating the right axis.
    right_yaxis = left_yaxis.twinx()
    axes = [left_yaxis, right_yaxis]

    # Creating the far right axis if show_rank_axis is True
    far_right_yaxis = None
    if show_rank_axis:
        far_right_yaxis = left_yaxis.twinx()
        axes.append(far_right_yaxis)

    x = df.index.values
    for col in df.columns:
        y = df[col]

        # Plotting blank points on the right axis/axes
        # so that they line up with the left axis.
        for axis in axes[1:]:
            axis.plot(x, y, alpha=0)

        left_yaxis.plot(x, y, **line_kwargs)

        # Adding scatter plots
        if scatter:
            left_yaxis.scatter(x, y, **scatter_args)

            # Adding see-through holes
            # NOTE(review): holes are only drawn together with the scatter
            # markers; confirm this nesting matches the intended behaviour.
            if holes:
                bg_color = left_yaxis.get_facecolor()
                left_yaxis.scatter(x, y, color=bg_color, **hole_args)

    # Number of lines
    lines = len(df.columns)
    y_ticks = list(range(1, lines + 1))

    # Configuring the axes so that they line up well.
    for axis in axes:
        axis.invert_yaxis()
        axis.set_yticks(y_ticks)
        axis.set_ylim((lines + 0.5, 0.5))

    # Sorting the labels to match the ranks.
    left_labels = df.iloc[0].sort_values().index
    right_labels = df.iloc[-1].sort_values().index
    left_yaxis.set_yticklabels(left_labels)
    right_yaxis.set_yticklabels(right_labels)

    # Setting the position of the far right axis so that it doesn't overlap
    # with the right axis
    if far_right_yaxis is not None:
        far_right_yaxis.spines["right"].set_position(("axes", rank_axis_distance))

    return axes