ZebraLogic / data_utils.py
yuchenlin's picture
inti commit
1c919b3
raw
history blame
1.33 kB
from datasets import load_dataset, Dataset
import os
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
from constants import column_names, RANKING_COLUMN, ORDERED_COLUMN_NAMES
from utils_display import make_clickable_model
import random
disable_progress_bar()
import math
import json
from tqdm import tqdm
import numpy as np
# Module-level placeholders, all initialised to None at import time.
# NOTE(review): nothing in the visible part of this file reads or assigns
# them; presumably they are lazily-populated caches filled in by loader
# functions elsewhere -- confirm before removing or renaming.
id_to_data = None
model_len_info = None
bench_data = None
eval_results = None
score_eval_results = None
# Formats the columns
def formatter(x):
    """Return *x* rounded to one decimal place, leaving text untouched.

    Args:
        x: A single cell value. Strings (including ``str`` subclasses such
            as ``numpy.str_``) and ``None`` are returned unchanged; any
            other value is passed to the built-in ``round(x, 1)``.

    Returns:
        The original value for strings and ``None``, otherwise
        ``round(x, 1)``.

    Raises:
        TypeError: If *x* is neither text nor ``None`` and does not support
            ``round()``.
    """
    # isinstance() rather than an exact type comparison, so that str
    # subclasses are recognised as text instead of being handed to round().
    # None is passed through so a missing cell does not abort formatting.
    if x is None or isinstance(x, str):
        return x
    return round(x, 1)
def post_processing(df, column_names, rank_column=RANKING_COLUMN, ordered_columns=ORDERED_COLUMN_NAMES, click_url=True):
    """Format, rename, reorder and sort the columns of a DataFrame.

    Steps, in order:
      1. For every column: the ``"Model"`` column is passed through
         ``make_clickable_model`` when *click_url* is true; every other
         column (and ``"Model"`` when *click_url* is false) is passed
         through ``formatter``. Columns whose name contains ``"Elo"`` then
         have ``'-'`` replaced by NaN and are cast to float.
      2. Columns are renamed according to *column_names*.
      3. Only the columns listed in *ordered_columns* that exist after the
         rename are kept, in that order.
      4. If *rank_column* is among them, rows are sorted by it, descending.

    Note:
        Steps 1 and 2 modify the caller's *df* in place (kept for backward
        compatibility). Steps 3 and 4 produce the new object that is
        returned.

    Args:
        df: DataFrame to process (pandas-style API).
        column_names: Mapping of old column name -> new column name, passed
            to ``df.rename(columns=...)``.
        rank_column: Name (after renaming) of the column to sort by.
        ordered_columns: Iterable of column names (after renaming) giving
            the columns to keep and their order.
        click_url: Whether to transform the ``"Model"`` column with
            ``make_clickable_model``.

    Returns:
        A DataFrame restricted to *ordered_columns* and, when possible,
        sorted by *rank_column* in descending order.
    """
    for col in df.columns:
        if col == "Model" and click_url:
            # Previously written as `x.replace(x, make_clickable_model(x))`,
            # which for any string x evaluates to make_clickable_model(x).
            df[col] = df[col].apply(make_clickable_model)
        else:
            df[col] = df[col].apply(formatter)  # For numerical values
        if "Elo" in col:
            # '-' marks a missing rating; turn it into NaN so the column
            # can be cast to float.
            df[col] = df[col].replace('-', np.nan).astype(float)
    df.rename(columns=column_names, inplace=True)
    list_columns = [col for col in ordered_columns if col in df.columns]
    df = df[list_columns]
    if rank_column in df.columns:
        # Sort out-of-place: `df` is now a selection derived from the
        # caller's frame, and an in-place sort on it can raise pandas'
        # SettingWithCopyWarning.
        df = df.sort_values(by=rank_column, ascending=False)
    return df