Spaces:
AIR-Bench
/
Running on CPU Upgrade

leaderboard / tests /test_utils.py
nan's picture
refactor: reformat with black
ec8e2d4
raw
history blame
4.01 kB
import pandas as pd
import pytest
from app import update_table
from src.envs import (
COL_NAME_AVG,
COL_NAME_IS_ANONYMOUS,
COL_NAME_RANK,
COL_NAME_RERANKING_MODEL,
COL_NAME_RETRIEVAL_MODEL,
COL_NAME_REVISION,
COL_NAME_TIMESTAMP,
)
from src.utils import (
filter_models,
filter_queries,
get_default_cols,
get_iso_format_timestamp,
search_table,
select_columns,
update_table_long_doc,
)
@pytest.fixture
def toy_df():
return pd.DataFrame(
{
"Retrieval Model": ["bge-m3", "bge-m3", "jina-embeddings-v2-base", "jina-embeddings-v2-base"],
"Reranking Model": ["bge-reranker-v2-m3", "NoReranker", "bge-reranker-v2-m3", "NoReranker"],
"Average ⬆️": [0.6, 0.4, 0.3, 0.2],
"wiki_en": [0.8, 0.7, 0.2, 0.1],
"wiki_zh": [0.4, 0.1, 0.4, 0.3],
"news_en": [0.8, 0.7, 0.2, 0.1],
"news_zh": [0.4, 0.1, 0.4, 0.3],
}
)
@pytest.fixture
def toy_df_long_doc():
return pd.DataFrame(
{
"Retrieval Model": ["bge-m3", "bge-m3", "jina-embeddings-v2-base", "jina-embeddings-v2-base"],
"Reranking Model": ["bge-reranker-v2-m3", "NoReranker", "bge-reranker-v2-m3", "NoReranker"],
"Average ⬆️": [0.6, 0.4, 0.3, 0.2],
"law_en_lex_files_300k_400k": [0.4, 0.1, 0.4, 0.3],
"law_en_lex_files_400k_500k": [0.8, 0.7, 0.2, 0.1],
"law_en_lex_files_500k_600k": [0.8, 0.7, 0.2, 0.1],
"law_en_lex_files_600k_700k": [0.4, 0.1, 0.4, 0.3],
}
)
def test_filter_models(toy_df):
df_result = filter_models(
toy_df,
[
"bge-reranker-v2-m3",
],
)
assert len(df_result) == 2
assert df_result.iloc[0]["Reranking Model"] == "bge-reranker-v2-m3"
def test_search_table(toy_df):
df_result = search_table(toy_df, "jina")
assert len(df_result) == 2
assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"
def test_filter_queries(toy_df):
df_result = filter_queries("jina", toy_df)
assert len(df_result) == 2
assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"
def test_select_columns(toy_df):
df_result = select_columns(
toy_df,
[
"news",
],
[
"zh",
],
)
assert len(df_result.columns) == 4
assert df_result["Average ⬆️"].equals(df_result["news_zh"])
def test_update_table_long_doc(toy_df_long_doc):
df_result = update_table_long_doc(
toy_df_long_doc,
[
"law",
],
[
"en",
],
[
"bge-reranker-v2-m3",
],
"jina",
)
print(df_result)
def test_get_iso_format_timestamp():
timestamp_config, timestamp_fn = get_iso_format_timestamp()
assert len(timestamp_fn) == 14
assert len(timestamp_config) == 20
assert timestamp_config[-1] == "Z"
def test_get_default_cols():
cols, types = get_default_cols("qa")
for c, t in zip(cols, types):
print(f"type({c}): {t}")
assert len(frozenset(cols)) == len(cols)
def test_update_table():
df = pd.DataFrame(
{
COL_NAME_IS_ANONYMOUS: [False, False, False],
COL_NAME_REVISION: ["a1", "a2", "a3"],
COL_NAME_TIMESTAMP: ["2024-05-12T12:24:02Z"] * 3,
COL_NAME_RERANKING_MODEL: ["NoReranker"] * 3,
COL_NAME_RETRIEVAL_MODEL: ["Foo"] * 3,
COL_NAME_RANK: [1, 2, 3],
COL_NAME_AVG: [0.1, 0.2, 0.3], # unsorted values
"wiki_en": [0.1, 0.2, 0.3],
}
)
results = update_table(
df,
"wiki",
"en",
["NoReranker"],
"",
show_anonymous=False,
reset_ranking=False,
show_revision_and_timestamp=False,
)
# keep the RANK as the same regardless of the unsorted averages
assert results[COL_NAME_RANK].to_list() == [1, 2, 3]