|
import json |
|
from huggingface_hub import HfApi, ModelFilter, DatasetFilter, ModelSearchArguments |
|
from pprint import pprint |
|
from hf_search import HFSearch |
|
import streamlit as st |
|
import itertools |
|
|
|
from pbr.version import VersionInfo

# Log the installed hf_search package version at startup (pbr reads it from
# the installed package metadata), so the running build is easy to identify.
print("hf_search version:", VersionInfo('hf_search').version_string())


# Module-level search client shared by the cached search functions below.
# NOTE(review): top_k=200 presumably caps how many candidates the underlying
# index returns per query — confirm against the HFSearch implementation.
hf_search = HFSearch(top_k=200)
|
|
|
@st.cache
def hf_api(query, limit=5, sort=None, filters=None):
    """Search the Hugging Face Hub model index via the official API.

    Parameters
    ----------
    query : str
        Free-text query forwarded to ``HfApi.list_models``.
    limit : int
        Maximum number of hits to return.
    sort : str or None
        Field to sort results by (passed through to the API).
    filters : dict or None
        Optional mapping with ``"task"`` and ``"library"`` keys used to build
        a ``ModelFilter``. Missing keys are treated as "no filter".

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}`` where each hit carries
        ``modelId``, ``tags``, ``downloads`` and ``likes`` and ``count``
        is the pre-truncation result count.
    """
    # Avoid the shared-mutable-default pitfall: normalize None to a fresh dict.
    filters = {} if filters is None else filters
    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    api = HfApi()
    # .get() keeps the call working when a filter key is absent
    # (the original filters["task"] raised KeyError on partial/empty dicts).
    filt = ModelFilter(
        task=filters.get("task"),
        library=filters.get("library"),
    )
    models = api.list_models(search=query, filter=filt, limit=limit, sort=sort, full=True)
    hits = []
    for model in models:
        # ModelInfo objects expose their fields via __dict__; read defensively.
        model = model.__dict__
        hits.append(
            {
                "modelId": model.get("modelId"),
                "tags": model.get("tags"),
                "downloads": model.get("downloads"),
                "likes": model.get("likes"),
            }
        )
    # Record the total before truncating so callers can display "N results".
    count = len(hits)
    if len(hits) > limit:
        hits = hits[:limit]
    return {"hits": hits, "count": count}
|
|
|
|
|
@st.cache
def semantic_search(query, limit=5, sort=None, filters=None):
    """Semantic ("retrieve & rerank") search over the hf_search index.

    Parameters
    ----------
    query : str
        Free-text query passed to ``hf_search.search``.
    limit : int
        Maximum number of hits to request.
    sort : str or None
        Sort criterion forwarded to the search backend.
    filters : dict or None
        Backend filter mapping; ``None`` means no filtering.

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}``; each hit carries ``modelId``,
        ``tags``, ``downloads``, ``likes`` and (optionally) ``readme``.
    """
    # Avoid the shared-mutable-default pitfall: normalize None to a fresh dict.
    filters = {} if filters is None else filters
    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    hits = hf_search.search(query=query, method="retrieve & rerank", limit=limit, sort=sort, filters=filters)
    # Project each backend hit onto the fields the UI renders;
    # "readme" may be missing, so default it to None.
    hits = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme", None),
        }
        for hit in hits
    ]
    return {"hits": hits, "count": len(hits)}
|
|
|
|
|
@st.cache
def bm25_search(query, limit=5, sort=None, filters=None):
    """BM25 keyword search over the hf_search index, de-duplicated by model id.

    Parameters
    ----------
    query : str
        Free-text query passed to ``hf_search.search``.
    limit : int
        Maximum number of hits to request.
    sort : str or None
        Sort criterion forwarded to the search backend.
    filters : dict or None
        Backend filter mapping; ``None`` means no filtering.

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}``; each hit carries ``modelId``,
        ``tags``, ``downloads``, ``likes`` and (optionally) ``readme``.
        Duplicate ``modelId`` entries are dropped, keeping the first
        (best-ranked) occurrence.
    """
    # Avoid the shared-mutable-default pitfall: normalize None to a fresh dict.
    filters = {} if filters is None else filters
    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    hits = hf_search.search(query=query, method="bm25", limit=limit, sort=sort, filters=filters)
    hits = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme", None),
        }
        for hit in hits
    ]
    # De-duplicate by modelId, preserving order. A set membership test
    # replaces the original per-element list rebuild, which was O(n^2).
    seen = set()
    deduped = []
    for hit in hits:
        if hit["modelId"] not in seen:
            seen.add(hit["modelId"])
            deduped.append(hit)
    return {"hits": deduped, "count": len(deduped)}
|
|
|
|
|
def paginator(label, articles, articles_per_page=10, on_sidebar=True):
    """Lets the user paginate a set of articles.

    Parameters
    ----------
    label : str
        The label to display over the pagination widget.
    articles : Iterator[Any]
        The articles to display in the paginator.
    articles_per_page : int
        The number of articles to display per page.
    on_sidebar : bool
        Whether to display the paginator widget on the sidebar.

    Returns
    -------
    Iterator[Tuple[int, Any]]
        An iterator over *only the articles on the selected page*,
        including each item's original index.
    """
    # Place the page selector either in the sidebar or inline.
    if on_sidebar:
        location = st.sidebar.empty()
    else:
        location = st.empty()

    # Materialize the iterable so we can count pages.
    articles = list(articles)
    # Ceiling division; an empty input yields zero pages.
    n_pages = (len(articles) - 1) // articles_per_page + 1

    # BUG FIX: the label previously hard-coded a page size of 10
    # (f"Results {i*10} to {i*10 + 10 - 1}"), which mislabeled pages
    # whenever articles_per_page != 10. Use the actual page size.
    def page_format_func(i):
        start = i * articles_per_page
        return f"Results {start} to {start + articles_per_page - 1}"

    page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)

    min_index = page_number * articles_per_page
    max_index = min_index + articles_per_page

    # islice keeps the enumeration lazy while preserving original indices.
    return itertools.islice(enumerate(articles), min_index, max_index)
|
|