Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 4,002 Bytes
9c49811 7845083 a3d4c8d ec8e2d4 6f9f649 ec8e2d4 9c49811 ec8e2d4 9c49811 f30cbcc ec8e2d4 f30cbcc ec8e2d4 9c49811 ec8e2d4 9c49811 ec8e2d4 9c49811 ec8e2d4 f30cbcc 6f9f649 ec8e2d4 5664d71 df659d0 5eb510c df659d0 6d7eea4 7ca7624 ec8e2d4 7ca7624 ec8e2d4 7ca7624 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import pandas as pd
import pytest
from app import update_table
from src.columns import (
COL_NAME_AVG,
COL_NAME_IS_ANONYMOUS,
COL_NAME_RANK,
COL_NAME_RERANKING_MODEL,
COL_NAME_RETRIEVAL_MODEL,
COL_NAME_REVISION,
COL_NAME_TIMESTAMP,
)
from src.utils import (
filter_models,
filter_queries,
get_default_cols,
get_iso_format_timestamp,
search_table,
select_columns,
update_doc_df_elem,
)
@pytest.fixture
def toy_df():
    """Return a four-row toy leaderboard with ``wiki``/``news`` scores in ``en``/``zh``.

    Each of the two retrieval models (``bge-m3``, ``jina-embeddings-v2-base``)
    appears once with ``bge-reranker-v2-m3`` and once with ``NoReranker``.
    The ``Average ⬆️`` column equals the row mean of the four benchmark columns.
    """
    retrievers = ["bge-m3"] * 2 + ["jina-embeddings-v2-base"] * 2
    rerankers = ["bge-reranker-v2-m3", "NoReranker"] * 2
    # Both domains share the same per-language scores.
    en_scores = [0.8, 0.7, 0.2, 0.1]
    zh_scores = [0.4, 0.1, 0.4, 0.3]
    columns = {
        "Retrieval Model": retrievers,
        "Reranking Model": rerankers,
        "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
        "wiki_en": en_scores,
        "wiki_zh": zh_scores,
        "news_en": en_scores,
        "news_zh": zh_scores,
    }
    return pd.DataFrame(columns)
@pytest.fixture
def toy_df_long_doc():
    """Return a four-row toy long-doc leaderboard with four ``law_en_lex_files_*`` columns.

    Models and rerankers are laid out as in the QA toy frame: each retrieval
    model is paired once with ``bge-reranker-v2-m3`` and once with ``NoReranker``.
    ``Average ⬆️`` equals the row mean of the four benchmark columns.
    """
    scores_a = [0.4, 0.1, 0.4, 0.3]
    scores_b = [0.8, 0.7, 0.2, 0.1]
    frame = {
        "Retrieval Model": ["bge-m3"] * 2 + ["jina-embeddings-v2-base"] * 2,
        "Reranking Model": ["bge-reranker-v2-m3", "NoReranker"] * 2,
        "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
        "law_en_lex_files_300k_400k": scores_a,
        "law_en_lex_files_400k_500k": scores_b,
        "law_en_lex_files_500k_600k": scores_b,
        "law_en_lex_files_600k_700k": scores_a,
    }
    return pd.DataFrame(frame)
def test_filter_models(toy_df):
    """Filtering on ``bge-reranker-v2-m3`` keeps exactly the rows that use that reranker.

    The toy frame contains two such rows, so both the row count and the
    reranker of every surviving row are checked.
    """
    df_result = filter_models(toy_df, ["bge-reranker-v2-m3"])
    assert len(df_result) == 2
    # Check every row, not just the first: a half-correct filter must fail.
    assert df_result["Reranking Model"].tolist() == ["bge-reranker-v2-m3"] * 2
def test_search_table(toy_df):
    """Searching for ``jina`` keeps only the two ``jina-embeddings-v2-base`` rows."""
    df_result = search_table(toy_df, "jina")
    assert len(df_result) == 2
    # Check every row, not just the first: a partial match on one row must fail.
    assert df_result["Retrieval Model"].tolist() == ["jina-embeddings-v2-base"] * 2
def test_filter_queries(toy_df):
    """Filtering on the query ``jina`` keeps only the two ``jina-embeddings-v2-base`` rows."""
    # Note: filter_queries takes the query string first and the frame second.
    df_result = filter_queries("jina", toy_df)
    assert len(df_result) == 2
    # Check every row, not just the first: a partial match on one row must fail.
    assert df_result["Retrieval Model"].tolist() == ["jina-embeddings-v2-base"] * 2
def test_select_columns(toy_df):
    """Selecting domain ``news`` and language ``zh`` leaves one benchmark column.

    The result has four columns (presumably the two model columns, the average
    and ``news_zh``). Since ``news_zh`` is the only benchmark left, the average
    must be recomputed to equal it; the fixture's stored average differs from it.
    """
    selected = select_columns(toy_df, ["news"], ["zh"])
    assert len(selected.columns) == 4
    remaining_benchmark = selected["news_zh"]
    assert selected["Average ⬆️"].equals(remaining_benchmark)
def test_update_table_long_doc(toy_df_long_doc):
    """Smoke-test ``update_doc_df_elem`` on the long-doc toy frame.

    The call passes ``["law"]``, ``["en"]``, ``["bge-reranker-v2-m3"]`` and
    ``"jina"``. These are presumably the domain, language, reranker and query
    filters (TODO confirm). The result is printed for manual inspection.

    NOTE(review): this test makes no assertions on the result. Add expectations
    once the return contract of ``update_doc_df_elem`` is confirmed.
    """
    domains = ["law"]
    langs = ["en"]
    rerankers = ["bge-reranker-v2-m3"]
    query = "jina"
    df_result = update_doc_df_elem(toy_df_long_doc, domains, langs, rerankers, query)
    print(df_result)
def test_get_iso_format_timestamp():
    """Check the lengths and suffix of the two strings from ``get_iso_format_timestamp``.

    The first value must be 20 characters long and end in ``Z``. That is
    consistent with a UTC ISO-8601 form such as ``2024-05-12T12:24:02Z``.
    The second must be 14 characters long, presumably a compact
    ``YYYYMMDDHHMMSS`` variant (TODO confirm).
    """
    iso_stamp, compact_stamp = get_iso_format_timestamp()
    assert len(compact_stamp) == 14
    assert len(iso_stamp) == 20
    assert iso_stamp.endswith("Z")
def test_get_default_cols():
    """``get_default_cols("qa")`` returns parallel name/type lists with unique names.

    Each column name is printed alongside its type for manual inspection.
    """
    cols, types = get_default_cols("qa")
    # zip() silently truncates to the shorter input, so a mismatch between the
    # two lists would otherwise go unnoticed; check it explicitly.
    assert len(cols) == len(types)
    for c, t in zip(cols, types):
        print(f"type({c}): {t}")
    assert len(set(cols)) == len(cols)
def test_update_table():
    """``update_table`` must leave the rank column untouched when ``reset_ranking=False``.

    The averages are deliberately ascending while the ranks run 1..3. Re-sorting
    the rows by descending average would therefore reorder the rank column and
    fail this check.
    """
    n_rows = 3
    expected_ranks = [1, 2, 3]
    leaderboard = pd.DataFrame(
        {
            COL_NAME_IS_ANONYMOUS: [False] * n_rows,
            COL_NAME_REVISION: ["a1", "a2", "a3"],
            COL_NAME_TIMESTAMP: ["2024-05-12T12:24:02Z"] * n_rows,
            COL_NAME_RERANKING_MODEL: ["NoReranker"] * n_rows,
            COL_NAME_RETRIEVAL_MODEL: ["Foo"] * n_rows,
            COL_NAME_RANK: list(expected_ranks),
            COL_NAME_AVG: [0.1, 0.2, 0.3],  # ascending on purpose: not in rank order
            "wiki_en": [0.1, 0.2, 0.3],
        }
    )
    updated = update_table(
        leaderboard,
        "wiki",
        "en",
        ["NoReranker"],
        "",
        show_anonymous=False,
        reset_ranking=False,
        show_revision_and_timestamp=False,
    )
    assert updated[COL_NAME_RANK].to_list() == expected_ranks
|