|
import streamlit as st |
|
from st_utils import bm25_search, semantic_search, hf_api, paginator |
|
from huggingface_hub import ModelSearchArguments |
|
import webbrowser |
|
from numerize.numerize import numerize |
|
import math |
|
|
|
# Page-wide Streamlit settings: browser-tab title and icon, wide layout,
# and the sidebar left in its automatic initial state.
_page_settings = {
    "page_title": "TNR LIBRERY",
    "page_icon": "♾️",
    "layout": "wide",
    "initial_sidebar_state": "auto",
}
st.set_page_config(**_page_settings)
|
|
|
|
|
# Sidebar labels keyed by the value each selector returns. Insertion order
# is the order in which the options are offered.
_backend_labels = {
    "semantic": "Semantic Search",
    "bm25": "BM25 search",
    "hfapi": "Keyword search",
}
_sort_labels = {
    None: "Relevance",
    "downloads": "Most downloads",
    "likes": "Most likes",
    "lastModified": "Recently updated",
}

# Which search helper to use for the query.
search_backend = st.sidebar.selectbox(
    "Search method",
    list(_backend_labels),
    format_func=_backend_labels.__getitem__,
)

# Maximum number of results requested from the search helper, as an int.
limit_results = int(
    st.sidebar.number_input("Limit results", min_value=0, value=10)
)

# Sort key passed to the search helper; None is labelled "Relevance".
sort_by = st.sidebar.selectbox(
    "Sort by",
    list(_sort_labels),
    format_func=_sort_labels.__getitem__,
)
|
|
|
st.sidebar.markdown("# Filters")

args = ModelSearchArguments()

# The multiselects offer the *values* of args.library / args.pipeline_tag
# but display the matching *keys*. Invert each mapping once here rather
# than rebuilding the inverted dict for every option that gets rendered.
_library_labels = {value: label for label, value in args.library.items()}
_task_labels = {value: label for label, value in args.pipeline_tag.items()}

# Selected library values; used later as the "library" search filter.
library = st.sidebar.multiselect(
    "Library",
    args.library.values(),
    format_func=_library_labels.__getitem__,
)

# Selected pipeline-tag values; used later as the "task" search filter.
task = st.sidebar.multiselect(
    "Task",
    args.pipeline_tag.values(),
    format_func=_task_labels.__getitem__,
)
|
|
|
|
|
# Centered page heading, written as raw HTML (hence unsafe_allow_html).
_heading_html = "<h1 style='text-align: center; '>♾️ TNR LIBRERY</h1>"
st.markdown(_heading_html, unsafe_allow_html=True)

# Free-text query box; the search below only runs when this is non-empty.
search_query = st.text_input(
    "Search for a model in HuggingFace",
    value="",
    max_chars=None,
    key=None,
    type="default",
)
|
|
|
def _format_count(value):
    """Return *value* abbreviated by ``numerize``, or ``"N/A"``.

    ``"N/A"`` is returned when *value* is falsy (``None``, and also ``0``)
    or NaN. The truthiness test runs first so that ``math.isnan`` is never
    called on ``None``.
    """
    if not value or math.isnan(value):
        return "N/A"
    return numerize(value)


def _open_model_site():
    """Button callback: open the project site in a new browser tab."""
    # NOTE(review): webbrowser acts on the machine running this Python
    # process. If the app is served remotely, the tab would open on the
    # server rather than in the visitor's browser -- confirm this is
    # intended, or switch to a client-side link.
    webbrowser.open("https://libt.lpmotortest.com", new=2)


# Search helper for each value offered by the "Search method" selector.
# All three are called with the same four arguments.
_search_backends = {
    "hfapi": hf_api,
    "semantic": semantic_search,
    "bm25": bm25_search,
}

if search_query:
    filters = {
        "library": library,
        "task": task,
    }
    # A table lookup fails loudly on an unexpected backend value; the
    # previous if/elif chain had no fallback and left `res` unassigned.
    res = _search_backends[search_backend](
        search_query, limit_results, sort_by, filters
    )
    hits_count = res["count"]

    # Keep only the fields the page displays; "readme" is optional.
    hit_list = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme"),
        }
        for hit in res["hits"]
    ]

    if hit_list:
        st.write(f"Search results ({hits_count}):")

        # The "showing" figure in the paginator label is capped at 100.
        shown_results = min(hits_count, 100)

        for i, hit in paginator(
            f"Select results (showing {shown_results} of {hits_count} results)",
            hit_list,
        ):
            col1, col2, col3 = st.columns([5, 1, 1])
            col1.metric("Model", hit["modelId"])
            col2.metric("N° downloads", _format_count(hit["downloads"]))
            col3.metric("N° likes", _format_count(hit["likes"]))

            st.button(
                "View model on ♾️",
                on_click=_open_model_site,
                # Index plus model id keeps every button's widget key unique.
                key=f"{i}-{hit['modelId']}",
            )

            tags_line = " • ".join(hit["tags"])
            st.write(f"**Tags:** {tags_line}")

            if hit["readme"]:
                with st.expander("See README"):
                    st.write(hit["readme"])

            # NOTE(review): the file's indentation was lost, so nesting was
            # reconstructed from syntax. This rule is placed inside the loop
            # as a separator after each result -- confirm.
            st.markdown("---")

    else:
        st.write("No Search results 😔")
|
|
|
# Page footer with credit and project links, written as raw HTML (hence
# unsafe_allow_html). Adjacent literals are joined into a single string.
_footer_html = (
    "<h6 style='text-align: center; color: #808080;'>"
    "Made with ❤️ By <a href='https://tnr.lpmotortest.com'>TNR Studio</a>"
    " - Checkout complete project <a href='https://bit.ly/libtrefresh'>here</a>"
    "</h6>"
)
st.markdown(_footer_html, unsafe_allow_html=True)