Spaces:
Sleeping
Sleeping
import json | |
import os | |
from datetime import datetime | |
import dotenv | |
import lancedb | |
from datasets import load_dataset | |
from fasthtml.common import * # noqa | |
from huggingface_hub import login, whoami | |
from rerankers import Reranker | |
# Pull settings from a local .env file (if present) into the environment.
dotenv.load_dotenv()

# Authenticate against the Hugging Face Hub; the dataset repo lives under
# the account that owns the token.
_hf_token = os.environ.get("HF_TOKEN")
login(token=_hf_token)
hf_user = whoami(_hf_token)["name"]
HF_REPO_ID_TXT = f"{hf_user}/zotero-answer-ai-texts"

# Two configs of the same repo: paper abstracts and full article contents.
abstract_ds, article_ds = (
    load_dataset(HF_REPO_ID_TXT, config_name)["train"]
    for config_name in ("abstracts", "articles")
)

# ColBERT re-ranker applied on top of the full-text-search candidates.
ranker = Reranker("answerdotai/answerai-colbert-small-v1", model_type="colbert")

# LanceDB connection that hosts the full-text-search table.
uri = "data/zotero-fts"
db = lancedb.connect(uri)
# Index the datasets by arXiv id so abstract, title and sections can be joined.
id2abstract = {example["arxiv_id"]: example["abstract"] for example in abstract_ds}
id2content = {example["arxiv_id"]: example["contents"] for example in article_ds}
id2title = {example["arxiv_id"]: example["title"] for example in article_ds}
arxiv_ids = set(id2abstract)

# Build one row per paper for the full-text-search table.
# NOTE(review): every id in the abstracts config is assumed to also exist in
# the articles config; a missing id raises KeyError here — confirm upstream.
data = []
for arxiv_id in arxiv_ids:
    abstract = id2abstract[arxiv_id]
    title = id2title[arxiv_id]
    # Separate the paper title and each section with a blank line. Plain
    # concatenation glued the last word of one piece onto the first word of
    # the next, producing bogus tokens in the full-text index.
    full_text = "\n\n".join(
        [title]
        + [f"{item['title']}\n\n{item['content']}" for item in id2content[arxiv_id]]
    )
    data.append(
        {
            "arxiv_id": arxiv_id,
            "title": title,
            "abstract": abstract,
            "full_text": full_text,
        }
    )

# Recreate the table on every start-up and index the combined text for FTS.
table = db.create_table("articles", data=data, mode="overwrite")
table.create_fts_index("full_text", replace=True)
# format results ---- | |
def _format_results(results): | |
ret = [] | |
for result in results: | |
arx_id = result["arxiv_id"] | |
title = result["title"] | |
abstract = result["abstract"] | |
if "Abstract\n\n" in abstract: | |
abstract = abstract.split("Abstract\n\n")[-1] | |
this_ex = { | |
"title": title, | |
"url": f"https://arxiv.org/abs/{arx_id}", | |
"abstract": abstract, | |
} | |
ret.append(this_ex) | |
return ret | |
def retrieve_and_rerank(query, k=5, n_fetch=25):
    """Full-text search the articles table, then re-rank the hits.

    Args:
        query: Free-text search query.
        k: Number of re-ranked results to return.
        n_fetch: Number of full-text-search candidates handed to the
            re-ranker. Should be at least `k`.

    Returns:
        A list of at most `k` dicts as produced by `_format_results`
        (`title`, `url`, `abstract`), best match first. Empty when the
        full-text search finds nothing.
    """
    # retrieve ---
    retrieved = (
        table.search(query, vector_column_name="", query_type="fts")
        .limit(n_fetch)
        .select(["arxiv_id", "title", "abstract"])
        .to_list()
    )
    # Nothing matched: skip the re-ranker instead of calling it with no docs.
    if not retrieved:
        return []

    # re-rank on title + abstract
    docs = [f"{item['title']} {item['abstract']}" for item in retrieved]
    results = ranker.rank(query=query, docs=docs)

    # No explicit doc ids are passed to the ranker, so `doc_id` is used as the
    # position of the document in `docs` (and therefore in `retrieved`).
    ranked_doc_ids = [result.doc_id for result in results[:k]]
    final_results = [retrieved[idx] for idx in ranked_doc_ids]
    return _format_results(final_results)
###########################################################################
# FastHTML app -----
###########################################################################
# Page stylesheet: dark colour scheme, centred layout, log-entry boxes and
# the spinner shown while an htmx request is in flight.
_PAGE_CSS = """
:root {
    color-scheme: dark;
}
body {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
    line-height: 1.6;
}
#query {
    width: 100%;
    margin-bottom: 1rem;
}
#search-form button {
    width: 100%;
}
#search-results, #log-entries {
    margin-top: 2rem;
}
.log-entry {
    border: 1px solid #ccc;
    padding: 10px;
    margin-bottom: 10px;
}
.log-entry pre {
    white-space: pre-wrap;
    word-wrap: break-word;
}
.htmx-indicator {
    display: none;
}
.htmx-request .htmx-indicator {
    display: inline-block;
}
.spinner {
    display: inline-block;
    width: 2.5em;
    height: 2.5em;
    border: 0.3em solid rgba(255,255,255,.3);
    border-radius: 50%;
    border-top-color: #fff;
    animation: spin 1s ease-in-out infinite;
    margin-left: 10px;
    vertical-align: middle;
}
@keyframes spin {
    to { transform: rotate(360deg); }
}
.searching-text {
    font-size: 1.2em;
    font-weight: bold;
    color: #fff;
    margin-right: 10px;
    vertical-align: middle;
}
"""
style = Style(_PAGE_CSS)
# get the fast app and route
app, rt = fast_app(hdrs=(style,))

# Initialize a database to store search logs --
# SQLite cannot create missing parent directories, so make sure the folder
# exists before opening the database file.
os.makedirs("log_data", exist_ok=True)
# NOTE(review): this rebinds the module-level `db`, which held the LanceDB
# connection until here. `table` was created before the rebind, so search is
# unaffected, but any later use of `db` gets the log database.
db = database("log_data/search_logs.db")
search_logs = db.t.search_logs
if search_logs not in db.t:
    # First run: create the log table with an integer primary key.
    search_logs.create(
        id=int,
        timestamp=str,
        query=str,
        results=str,
        pk="id",
    )
# Row type used when inserting log entries.
SearchLog = search_logs.dataclass()
def insert_log_entry(log_entry):
    "Store one log entry (timestamp, query, results) as a row of `search_logs`"
    # The timestamp is stored as ISO-8601 text and the results as a JSON string.
    row = SearchLog(
        timestamp=log_entry["timestamp"].isoformat(),
        query=log_entry["query"],
        results=json.dumps(log_entry["results"]),
    )
    return search_logs.insert(row)
# Registered on "/" — the path the logs page links back to. Without the route
# decorator this handler is never served and is shadowed by the later `get`.
@rt("/")
async def get():
    "Render the search page: query form, results placeholder and a link to the logs"
    query_form = Form(
        Textarea(id="query", name="query", placeholder="Enter your query..."),
        Button("Submit", type="submit"),
        # Shown only while a request is in flight (see the .htmx-indicator CSS).
        Div(
            Span("Searching...", cls="searching-text htmx-indicator"),
            Span(cls="spinner htmx-indicator"),
            cls="indicator-container",
        ),
        id="search-form",
        # Submit via htmx and place the response into #search-results.
        hx_post="/search",
        hx_target="#search-results",
        hx_indicator=".indicator-container",
    )
    results_div = Div(Div(id="search-results", cls="results-container"))
    view_logs_link = A("View Logs", href="/logs", cls="view-logs-link")
    return Titled(
        "Zotero Search", Div(query_form, results_div, view_logs_link, cls="container")
    )
def SearchResult(result):
    "Card component for one search hit: linked title, abstract and a 'Read more' footer"
    title, url = result["title"], result["url"]
    heading = H4(A(title, href=url, target="_blank"))
    summary = P(result["abstract"])
    read_more = A("Read more →", href=url, target="_blank")
    return Card(heading, summary, footer=read_more)
def log_query_and_results(query, results):
    "Log the query with the current time and the title/url of each returned result"
    stamp = datetime.now()
    # Abstracts are left out of the log; only title and url are kept.
    slim_results = [{"title": r["title"], "url": r["url"]} for r in results]
    insert_log_entry({"timestamp": stamp, "query": query, "results": slim_results})
# Registered on "/search" — the path the search form posts to. Without the
# route decorator the form submission has no handler.
@rt("/search")
async def post(query: str):
    "Run the search for `query`, log it, and return the result cards"
    # Blank submissions: return the empty results container without
    # searching or writing a log row.
    if not query or not query.strip():
        return Div(id="search-results")
    results = retrieve_and_rerank(query)
    log_query_and_results(query, results)
    return Div(*[SearchResult(r) for r in results], id="search-results")
def LogEntry(entry):
    "Component for one stored search log row: query, timestamp and raw results text"
    parts = [
        H4(f"Query: {entry.query}"),
        P(f"Timestamp: {entry.timestamp}"),
        H5("Results:"),
        # `results` is shown exactly as stored, inside a <pre> block.
        Pre(entry.results),
    ]
    return Div(*parts, cls="log-entry")
# Registered on "/logs" — the path the search page's "View Logs" link points
# to. Without the route decorator this handler is never served.
@rt("/logs")
async def get():
    "Render the logs page with the most recent search log entries"
    logs = search_logs(order_by="-id", limit=50)  # Get the latest 50 logs
    log_entries = [LogEntry(log) for log in logs]
    return Titled(
        "Logs",
        Div(
            H2("Recent Search Logs"),
            Div(*log_entries, id="log-entries"),
            A("Back to Search", href="/", cls="back-link"),
            cls="container",
        ),
    )
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces; the port comes from $PORT when set, else 7860.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
    # run_uv()