"""Gradio demo: BM25 search over the SciQ corpus.

Builds a BM25 index over the SciQ corpus, saves it to disk, reloads it as a
retriever, and exposes a ``search`` function through a Gradio interface
(assigned to ``demo``).
"""

from typing import Dict, List, Optional, TypedDict

import gradio as gr

from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from bm25 import BM25Index, BM25Retriever

# Directory the index is written to and then reloaded from.
INDEX_DIR = "output/bm25_sciq_index"

sciq = load_sciq()
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    # NOTE(review): hard-coded corpus size; presumably len(sciq.corpus) — confirm.
    ndocs=12160,
    show_progress_bar=True,
    k1=0.8,  # Tuned on dev wrt. MAP@10
    b=0.6,  # Tuned on dev wrt. MAP@10
)
bm25_index.save(INDEX_DIR)
bm25_retriever = BM25Retriever(index_dir=INDEX_DIR)


class Hit(TypedDict):
    # Collection id of the retrieved document.
    cid: str
    # BM25 score assigned by the retriever.
    score: float
    # Full text of the document, looked up from the SciQ corpus.
    text: str


demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE

# Number of hits returned per query; matches the "top-10" promise in the UI text.
TOP_K = 10

# Collection id -> document text, so hits can carry the text alongside the score.
cid2doc: Dict[str, str] = {doc.collection_id: doc.text for doc in sciq.corpus}


def search(query: str) -> List[Hit]:
    """Retrieve the highest-scoring SciQ documents for *query* using BM25.

    Args:
        query: Free-text query entered in the demo.

    Returns:
        At most ``TOP_K`` hits sorted by BM25 score, highest first. Each hit
        holds the collection id, its score, and the document text. A blank
        query yields an empty list without calling the retriever.
    """
    if not query or not query.strip():
        return []
    ranking: Dict[str, float] = bm25_retriever.retrieve(query)
    # Do not rely on the mapping's iteration order; rank explicitly by score.
    ranked = sorted(ranking.items(), key=lambda item: item[1], reverse=True)
    return [
        Hit(cid=cid, score=score, text=cid2doc[cid])
        for cid, score in ranked[:TOP_K]
    ]


demo = gr.Interface(
    fn=search,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs="text",
    title="BM25 Retriever Search",
    description="Search using a BM25 retriever on [SciQ](https://huggingface.co/datasets/allenai/sciq) and return top-10 ranked documents with scores.",
)

## YOUR_CODE_ENDS_HERE

# print(demo.local_url)
# demo.launch()
# start a thread to run the demo
# import threading
# thread = threading.Thread(target=demo.launch)
# thread.start()
# import time
# time.sleep(5)
# print(demo.local_url)
# print(demo.local_api_url)
# thread.join()