File size: 10,137 Bytes
e13b3b8 |
1 |
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": ["# Gradio Demo: mini_leaderboard"]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": ["%pip install -q gradio pandas"]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Downloading files from the demo repo\n",
    "import os\n",
    "# exist_ok lets this cell be re-run once assets/ already exists.\n",
    "os.makedirs('assets', exist_ok=True)\n",
    "!wget -q -O assets/__init__.py https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/__init__.py\n",
    "!wget -q -O assets/custom_css.css https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/custom_css.css\n",
    "!wget -q -O assets/leaderboard_data.json https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/leaderboard_data.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44380577570523278879349135829904343037",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "# `__file__` only exists when this code runs as a script; inside a notebook\n",
    "# kernel it raises NameError, so fall back to the working directory, which is\n",
    "# where the previous cell downloaded `assets/`.\n",
    "try:\n",
    "    abs_path = Path(__file__).parent.absolute()\n",
    "except NameError:\n",
    "    abs_path = Path.cwd()\n",
    "\n",
    "df = pd.read_json(str(abs_path / \"assets/leaderboard_data.json\"))\n",
    "invisible_df = df.copy()\n",
    "\n",
    "COLS = [\n",
    "    \"T\",\n",
    "    \"Model\",\n",
    "    \"Average \u2b06\ufe0f\",\n",
    "    \"ARC\",\n",
    "    \"HellaSwag\",\n",
    "    \"MMLU\",\n",
    "    \"TruthfulQA\",\n",
    "    \"Winogrande\",\n",
    "    \"GSM8K\",\n",
    "    \"Type\",\n",
    "    \"Architecture\",\n",
    "    \"Precision\",\n",
    "    \"Merged\",\n",
    "    \"Hub License\",\n",
    "    \"#Params (B)\",\n",
    "    \"Hub \u2764\ufe0f\",\n",
    "    \"Model sha\",\n",
    "    \"model_name_for_query\",\n",
    "]\n",
    "ON_LOAD_COLS = [\n",
    "    \"T\",\n",
    "    \"Model\",\n",
    "    \"Average \u2b06\ufe0f\",\n",
    "    \"ARC\",\n",
    "    \"HellaSwag\",\n",
    "    \"MMLU\",\n",
    "    \"TruthfulQA\",\n",
    "    \"Winogrande\",\n",
    "    \"GSM8K\",\n",
    "    \"model_name_for_query\",\n",
    "]\n",
    "# NOTE(review): TYPES has 22 entries but COLS has 18. Gradio's `datatype` is\n",
    "# one entry per column by position, so only the first 10 entries line up with\n",
    "# ON_LOAD_COLS (the columns shown on load) -- confirm the intended types for\n",
    "# the remaining columns.\n",
    "TYPES = [\n",
    "    \"str\",\n",
    "    \"markdown\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"str\",\n",
    "    \"str\",\n",
    "    \"str\",\n",
    "    \"str\",\n",
    "    \"bool\",\n",
    "    \"str\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"bool\",\n",
    "    \"str\",\n",
    "    \"bool\",\n",
    "    \"bool\",\n",
    "    \"str\",\n",
    "]\n",
    "# Size buckets for the \"Model sizes\" filter, in billions of parameters. Every\n",
    "# interval is closed on the right, so a boundary value (e.g. 2) falls in the\n",
    "# lower bucket.\n",
    "NUMERIC_INTERVALS = {\n",
    "    \"?\": pd.Interval(-1, 0, closed=\"right\"),\n",
    "    \"~1.5\": pd.Interval(0, 2, closed=\"right\"),\n",
    "    \"~3\": pd.Interval(2, 4, closed=\"right\"),\n",
    "    \"~7\": pd.Interval(4, 9, closed=\"right\"),\n",
    "    \"~13\": pd.Interval(9, 20, closed=\"right\"),\n",
    "    \"~35\": pd.Interval(20, 45, closed=\"right\"),\n",
    "    \"~60\": pd.Interval(45, 70, closed=\"right\"),\n",
    "    \"70+\": pd.Interval(70, 10000, closed=\"right\"),\n",
    "}\n",
    "MODEL_TYPE = [str(s) for s in df[\"T\"].unique()]\n",
    "Precision = [str(s) for s in df[\"Precision\"].unique()]\n",
    "\n",
    "\n",
    "# Searching and filtering\n",
    "def update_table(\n",
    "    hidden_df: pd.DataFrame,\n",
    "    columns: list,\n",
    "    type_query: list,\n",
    "    precision_query: list,\n",
    "    size_query: list,\n",
    "    query: str,\n",
    "):\n",
    "    \"\"\"Rebuild the visible table from the full hidden table and the UI controls.\n",
    "\n",
    "    Filters run in order: type/precision/size checkboxes, then the search bar,\n",
    "    then the column picker.\n",
    "    \"\"\"\n",
    "    filtered_df = filter_models(hidden_df, type_query, size_query, precision_query)\n",
    "    filtered_df = filter_queries(query, filtered_df)\n",
    "    return select_columns(filtered_df, columns)\n",
    "\n",
    "\n",
    "def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:\n",
    "    \"\"\"Return the rows whose `model_name_for_query` contains `query`.\n",
    "\n",
    "    Matching is case-insensitive and literal (regex=False), so characters such\n",
    "    as `(`, `[` or `+` typed into the search bar cannot raise a regex error.\n",
    "    Missing names count as non-matches.\n",
    "    \"\"\"\n",
    "    mask = df[\"model_name_for_query\"].str.contains(\n",
    "        query, case=False, regex=False, na=False\n",
    "    )\n",
    "    return df[mask]\n",
    "\n",
    "\n",
    "def select_columns(df: pd.DataFrame, columns: list) -> pd.DataFrame:\n",
    "    \"\"\"Keep the requested columns, ordered as in COLS rather than as selected.\"\"\"\n",
    "    return df[[c for c in COLS if c in df.columns and c in columns]]\n",
    "\n",
    "\n",
    "def filter_queries(query: str, filtered_df: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"Apply the `;`-separated search terms in `query` as an OR filter.\n",
    "\n",
    "    Rows matching any non-empty term are concatenated and de-duplicated on\n",
    "    (Model, Precision, Model sha). If the query is empty, or no term matches\n",
    "    anything, `filtered_df` is returned unchanged.\n",
    "    \"\"\"\n",
    "    matches = []\n",
    "    for term in (q.strip() for q in (query or \"\").split(\";\")):\n",
    "        if not term:\n",
    "            continue\n",
    "        hits = search_table(filtered_df, term)\n",
    "        if len(hits) > 0:\n",
    "            matches.append(hits)\n",
    "    if not matches:\n",
    "        return filtered_df\n",
    "    return pd.concat(matches).drop_duplicates(\n",
    "        subset=[\"Model\", \"Precision\", \"Model sha\"]\n",
    "    )\n",
    "\n",
    "\n",
    "def filter_models(\n",
    "    df: pd.DataFrame,\n",
    "    type_query: list,\n",
    "    size_query: list,\n",
    "    precision_query: list,\n",
    ") -> pd.DataFrame:\n",
    "    \"\"\"Keep the rows selected by the type, precision and size checkboxes.\n",
    "\n",
    "    type_query: MODEL_TYPE labels; only the first character of each label is\n",
    "        compared with the \"T\" column.\n",
    "    size_query: keys of NUMERIC_INTERVALS; rows whose \"#Params (B)\" is not\n",
    "        numeric fall in no bucket and are dropped.\n",
    "    precision_query: selected precisions; rows whose Precision is the string\n",
    "        \"None\" are always kept.\n",
    "    \"\"\"\n",
    "    type_emoji = [t[0] for t in type_query]\n",
    "    type_mask = df[\"T\"].isin(type_emoji)\n",
    "    precision_mask = df[\"Precision\"].isin(list(precision_query) + [\"None\"])\n",
    "\n",
    "    numeric_interval = pd.IntervalIndex(\n",
    "        sorted([NUMERIC_INTERVALS[s] for s in size_query])\n",
    "    )\n",
    "    params_column = pd.to_numeric(df[\"#Params (B)\"], errors=\"coerce\")\n",
    "    size_mask = params_column.apply(\n",
    "        lambda x: any(numeric_interval.contains(x))\n",
    "    ).astype(bool)\n",
    "\n",
    "    # All three masks are computed on `df` itself, so they share its index and\n",
    "    # can be combined directly instead of being re-aligned after each filter.\n",
    "    return df.loc[type_mask & precision_mask & size_mask]\n",
    "\n",
    "\n",
    "# Pass the stylesheet's text rather than its path so it is applied regardless\n",
    "# of whether the installed Gradio version treats `css` as a file path.\n",
    "demo = gr.Blocks(css=(abs_path / \"assets/custom_css.css\").read_text(encoding=\"utf-8\"))\n",
    "with demo:\n",
    "    gr.Markdown(\"\"\"Test Space of the LLM Leaderboard\"\"\", elem_classes=\"markdown-text\")\n",
    "\n",
    "    with gr.Tabs(elem_classes=\"tab-buttons\") as tabs:\n",
    "        with gr.TabItem(\"\ud83c\udfc5 LLM Benchmark\", elem_id=\"llm-benchmark-tab-table\", id=0):\n",
    "            with gr.Row():\n",
    "                with gr.Column():\n",
    "                    with gr.Row():\n",
    "                        search_bar = gr.Textbox(\n",
    "                            placeholder=\" \ud83d\udd0d Search for your model (separate multiple queries with `;`) and press ENTER...\",\n",
    "                            show_label=False,\n",
    "                            elem_id=\"search-bar\",\n",
    "                        )\n",
    "                    with gr.Row():\n",
    "                        shown_columns = gr.CheckboxGroup(\n",
    "                            choices=COLS,\n",
    "                            value=ON_LOAD_COLS,\n",
    "                            label=\"Select columns to show\",\n",
    "                            elem_id=\"column-select\",\n",
    "                            interactive=True,\n",
    "                        )\n",
    "                with gr.Column(min_width=320):\n",
    "                    filter_columns_type = gr.CheckboxGroup(\n",
    "                        label=\"Model types\",\n",
    "                        choices=MODEL_TYPE,\n",
    "                        value=MODEL_TYPE,\n",
    "                        interactive=True,\n",
    "                        elem_id=\"filter-columns-type\",\n",
    "                    )\n",
    "                    filter_columns_precision = gr.CheckboxGroup(\n",
    "                        label=\"Precision\",\n",
    "                        choices=Precision,\n",
    "                        value=Precision,\n",
    "                        interactive=True,\n",
    "                        elem_id=\"filter-columns-precision\",\n",
    "                    )\n",
    "                    filter_columns_size = gr.CheckboxGroup(\n",
    "                        label=\"Model sizes (in billions of parameters)\",\n",
    "                        choices=list(NUMERIC_INTERVALS.keys()),\n",
    "                        value=list(NUMERIC_INTERVALS.keys()),\n",
    "                        interactive=True,\n",
    "                        elem_id=\"filter-columns-size\",\n",
    "                    )\n",
    "\n",
    "            leaderboard_table = gr.components.Dataframe(\n",
    "                value=df[ON_LOAD_COLS],\n",
    "                headers=ON_LOAD_COLS,\n",
    "                datatype=TYPES,\n",
    "                elem_id=\"leaderboard-table\",\n",
    "                interactive=False,\n",
    "                visible=True,\n",
    "                column_widths=[\"2%\", \"33%\"],\n",
    "            )\n",
    "\n",
    "            # Dummy leaderboard for handling the case when the user uses backspace key:\n",
    "            # it is never an event output, so it keeps the full data every filter starts from.\n",
    "            hidden_leaderboard_table_for_search = gr.components.Dataframe(\n",
    "                value=invisible_df[COLS],\n",
    "                headers=COLS,\n",
    "                datatype=TYPES,\n",
    "                visible=False,\n",
    "            )\n",
    "\n",
    "            # Every control re-runs the whole filter pipeline with the same inputs.\n",
    "            filter_inputs = [\n",
    "                hidden_leaderboard_table_for_search,\n",
    "                shown_columns,\n",
    "                filter_columns_type,\n",
    "                filter_columns_precision,\n",
    "                filter_columns_size,\n",
    "                search_bar,\n",
    "            ]\n",
    "            search_bar.submit(update_table, filter_inputs, leaderboard_table)\n",
    "            for selector in [\n",
    "                shown_columns,\n",
    "                filter_columns_type,\n",
    "                filter_columns_precision,\n",
    "                filter_columns_size,\n",
    "            ]:\n",
    "                selector.change(\n",
    "                    update_table,\n",
    "                    filter_inputs,\n",
    "                    leaderboard_table,\n",
    "                    queue=True,\n",
    "                )\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.queue(default_concurrency_limit=40).launch()\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}