# Hugging Face Space file view — status: Running on Zero; commit 8a9d0f0; file size 4,180 bytes.
import gradio as gr
import polars as pl
from search import search
from table import df_orig
# All columns the MCP tools may return, in output order.
COLUMNS_MCP = [
    "title", "authors", "abstract",
    "cvf_page_url", "pdf_url", "supp_url",
    "arxiv_id", "paper_page", "bibtex",
    "space_ids", "model_ids", "dataset_ids",
    "upvotes", "num_comments",
    "project_page", "github",
    "row_index",
]

# Columns pre-selected in the UI column picker. Built by filtering COLUMNS_MCP so the
# defaults are always a subset of the full list and keep its ordering.
DEFAULT_COLUMNS_MCP = [
    column
    for column in COLUMNS_MCP
    if column in {
        "title", "authors", "abstract", "cvf_page_url", "pdf_url",
        "arxiv_id", "project_page", "github", "row_index",
    }
]
# Table served by the MCP tools: rename `cvf` -> `cvf_page_url` and `paper_id` -> `row_index`,
# then keep only COLUMNS_MCP (which also puts the columns in that order).
df_mcp = (
    df_orig
    .rename({"cvf": "cvf_page_url", "paper_id": "row_index"})
    .select(COLUMNS_MCP)
)
def search_papers(
    search_query: str,
    candidate_pool_size: int,
    num_results: int,
    columns: list[str],
) -> list[dict]:
    """Searches CVPR 2025 papers relevant to a user query in English.

    This function performs a semantic search over CVPR 2025 papers.
    It uses a dual-stage retrieval process:
    - First, it retrieves `candidate_pool_size` papers using dense vector similarity.
    - Then, it re-ranks them with a cross-encoder model to select the top `num_results` most relevant papers.
    - The search results are returned as a list of dictionaries.

    Note:
        The search query must be written in English. Queries in other languages are not supported.

    Args:
        search_query (str): The natural language query input by the user. Must be in English.
        candidate_pool_size (int): Number of candidate papers to retrieve using the dense vector model.
        num_results (int): Final number of top-ranked papers to return after re-ranking.
        columns (list[str]): The columns to select from the DataFrame.

    Returns:
        list[dict]: A list of dictionaries of the top-ranked papers matching the query, sorted by relevance.

    Raises:
        ValueError: If the query is empty or whitespace-only, if either size is less than 1,
            or if num_results exceeds candidate_pool_size.
    """
    # A whitespace-only query is as useless as an empty one; reject both up front.
    if not search_query or not search_query.strip():
        raise ValueError("Search query cannot be empty")

    # Sizes may arrive as floats (e.g. from a UI slider or a loosely-typed MCP client);
    # coerce so the retriever always receives integers.
    candidate_pool_size = int(candidate_pool_size)
    num_results = int(num_results)
    if candidate_pool_size < 1 or num_results < 1:
        raise ValueError("Candidate pool size and number of results must be at least 1")
    if num_results > candidate_pool_size:
        raise ValueError("Number of results must be less than or equal to candidate pool size")

    results = search(search_query, candidate_pool_size, num_results)
    hits = pl.DataFrame(results)
    # An empty result set yields a frame with no columns, so the rename below would fail.
    if hits.is_empty():
        return []

    # Attach paper metadata to each hit; the retriever's `paper_id` is `row_index` in df_mcp.
    df = hits.rename({"paper_id": "row_index"}).join(df_mcp, on="row_index", how="inner")
    # Row order after a polars join is not guaranteed, so re-establish the ranking explicitly.
    df = df.sort("ce_score", descending=True)
    return df.select(columns).to_dicts()
def get_metadata(row_index: int) -> dict:
    """Returns a dictionary of metadata for a CVPR 2025 paper at the given table row index.

    Args:
        row_index (int): The index of the paper in the internal paper list table, i.e. the
            `row_index` value returned by search_papers or get_table.

    Returns:
        dict: A dictionary containing metadata for the corresponding paper.

    Raises:
        IndexError: If no paper has the given row index.
    """
    matches = df_mcp.filter(pl.col("row_index") == row_index).to_dicts()
    # Fail with an actionable message instead of a bare "list index out of range";
    # the exception type is kept as IndexError for existing callers.
    if not matches:
        raise IndexError(f"No paper found with row_index={row_index}")
    return matches[0]
def get_table(columns: list[str]) -> list[dict]:
    """Returns a list of dictionaries of all CVPR 2025 papers.

    Args:
        columns (list[str]): The columns to select from the DataFrame.

    Returns:
        list[dict]: A list of dictionaries of all CVPR 2025 papers.
    """
    # Project the full table onto the requested columns, then emit one dict per paper.
    projected = df_mcp.select(columns)
    return projected.to_dicts()
# Gradio UI. Launching with mcp_server=True also serves the app's event handlers over MCP.
with gr.Blocks() as demo:
    search_query = gr.Textbox(label="Search", submit_btn=True)
    candidate_pool_size = gr.Slider(label="Candidate Pool Size", minimum=1, maximum=500, step=1, value=200)
    num_results = gr.Slider(label="Number of Results", minimum=1, maximum=400, step=1, value=100)
    column_names = gr.CheckboxGroup(label="Columns", choices=COLUMNS_MCP, value=DEFAULT_COLUMNS_MCP)
    # Bound the slider by the actual `row_index` values (renamed from `paper_id`) rather than
    # assuming they are exactly 0..len(df_mcp) - 1; identical when that assumption holds.
    row_index = gr.Slider(
        label="Row Index",
        minimum=df_mcp["row_index"].min(),
        maximum=df_mcp["row_index"].max(),
        step=1,
        value=df_mcp["row_index"].min(),
    )
    out = gr.JSON()
    search_papers_btn = gr.Button("Search Papers")
    get_metadata_btn = gr.Button("Get Metadata")
    get_table_btn = gr.Button("Get Table")

    # The textbox shows a submit button (submit_btn=True), so pressing it or Enter must
    # also run the search. gr.on registers both triggers under ONE endpoint, so the
    # API/MCP surface still has a single `search_papers` tool.
    gr.on(
        triggers=[search_papers_btn.click, search_query.submit],
        fn=search_papers,
        inputs=[search_query, candidate_pool_size, num_results, column_names],
        outputs=out,
    )
    get_metadata_btn.click(
        fn=get_metadata,
        inputs=row_index,
        outputs=out,
    )
    get_table_btn.click(
        fn=get_table,
        inputs=column_names,
        outputs=out,
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True)