huggingface-datasets-search-v2 / load_viewer_data.py
davanstrien's picture
davanstrien HF staff
load viewer data
3e2784f
raw
history blame
2.92 kB
import asyncio
import logging
import chromadb
import httpx
import requests
import stamina
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import InferenceClient
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map
from prep_viewer_data import prep_data
# Module-level logger. Its own threshold is pinned at INFO, so DEBUG records
# from this module (e.g. per-batch progress) are dropped unless this is changed.
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
def initialize_clients():
    """Create the vector-store client and the embedding inference client.

    The Chroma client is a ``PersistentClient`` built with no arguments, so it
    uses chromadb's default on-disk location. The inference client is bound to
    a fixed Hugging Face Inference Endpoint URL.

    Returns:
        tuple: ``(chroma_client, inference_client)``, in that order.
    """
    logger.info("Initializing clients")
    endpoint_url = (
        "https://bm143rfir2on1bkw.us-east-1.aws.endpoints.huggingface.cloud"
    )
    # Tuple elements are evaluated left to right, so Chroma is created first.
    clients = (
        chromadb.PersistentClient(),
        InferenceClient(endpoint_url),
    )
    return clients
def create_collection(chroma_client):
    """Fetch (or create on first use) the dataset-viewer description collection.

    The collection uses cosine distance for its HNSW index. It is attached to a
    SentenceTransformer embedding function; chroma uses that function whenever
    it must embed text itself, e.g. for text queries.

    Args:
        chroma_client: A chromadb client exposing ``create_collection``.

    Returns:
        The ``dataset-viewer-descriptions`` collection object.
    """
    logger.info("Creating or getting collection")
    # NOTE(review): trust_remote_code=True allows the model repository to run
    # its own Python code when the model is loaded — keep the repo pinned/trusted.
    query_embedder = SentenceTransformerEmbeddingFunction(
        model_name="davanstrien/dataset-viewer-descriptions-processed-st",
        trust_remote_code=True,
    )
    collection_options = {
        "name": "dataset-viewer-descriptions",
        "get_or_create": True,
        "embedding_function": query_embedder,
        "metadata": {"hnsw:space": "cosine"},
    }
    return chroma_client.create_collection(**collection_options)
@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text, client):
    """Return the feature-extraction embedding of ``text`` from ``client``.

    Only the first 8192 characters of ``text`` are sent. Calls that raise
    ``requests.HTTPError`` are retried by stamina: at most 3 attempts, with an
    initial wait of 10 seconds.

    Args:
        text: The string to embed.
        client: An object with a ``feature_extraction`` method
            (a ``huggingface_hub.InferenceClient`` in this module).

    Returns:
        Whatever ``client.feature_extraction`` returns.
    """
    max_chars = 8192  # character (not token) cap on the request payload
    return client.feature_extraction(text[:max_chars])
def embed_and_upsert_datasets(
    dataset_rows_and_ids, collection, inference_client, batch_size=10
):
    """Embed dataset rows in fixed-size batches and upsert them into ``collection``.

    Each row must provide ``"dataset_id"`` (used as the Chroma id) and
    ``"formatted_prompt"``. The prompt is prefixed with ``HUB_DATASET_PREVIEW: ``
    and embedded remotely. Within a batch the requests run concurrently through
    ``thread_map``. Only ids and embeddings are upserted; the prompt text is
    not stored as a document.

    Args:
        dataset_rows_and_ids: Sequence of row dicts (sliced, so it must support
            slicing and ``len``).
        collection: Target chromadb collection.
        inference_client: Client passed through to ``embed_card``.
        batch_size: Number of rows embedded and upserted per iteration.
    """
    total_rows = len(dataset_rows_and_ids)
    logger.info(f"Embedding and upserting {total_rows} datasets")

    def _embed(prompt):
        return embed_card(prompt, inference_client)

    for offset in tqdm(range(0, total_rows, batch_size)):
        chunk = dataset_rows_and_ids[offset : offset + batch_size]
        # Build (id, prompt) pairs so keys are looked up row by row, id first.
        pairs = [
            (row["dataset_id"], f"HUB_DATASET_PREVIEW: {row['formatted_prompt']}")
            for row in chunk
        ]
        chunk_ids = [row_id for row_id, _ in pairs]
        prompts = [prompt for _, prompt in pairs]
        vectors = thread_map(_embed, prompts, leave=False)
        # NOTE(review): assumes each response is an array whose first row is
        # the sentence embedding (e.g. shape (1, dim)) — confirm with endpoint.
        flat_vectors = [vector.tolist()[0] for vector in vectors]
        collection.upsert(ids=chunk_ids, embeddings=flat_vectors)
        logger.debug(f"Processed batch {offset // batch_size + 1}")
async def refresh_viewer_data(sample_size=100_000, min_likes=2):
    """Rebuild the dataset-viewer embedding collection end to end.

    The refresh proceeds in four steps:

    1. Create the Chroma and inference clients.
    2. Get or create the target collection.
    3. Await ``prep_data`` to obtain the rows.
    4. Embed every row and upsert the vectors.

    Args:
        sample_size: Forwarded to ``prep_data``.
        min_likes: Forwarded to ``prep_data``.
    """
    logger.info(
        f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
    )
    loop = asyncio.get_running_loop()
    # Client setup, model loading and the batched HTTP embedding/upsert work
    # are all synchronous. Run them in the default executor so this coroutine
    # does not freeze the caller's event loop while they run.
    chroma_client, inference_client = await loop.run_in_executor(
        None, initialize_clients
    )
    collection = await loop.run_in_executor(None, create_collection, chroma_client)
    logger.info("Preparing data")
    df = await prep_data(sample_size=sample_size, min_likes=min_likes)
    dataset_rows_and_ids = df.to_dicts()
    # embed_and_upsert_datasets logs the row count itself; no duplicate here.
    await loop.run_in_executor(
        None,
        embed_and_upsert_datasets,
        dataset_rows_and_ids,
        collection,
        inference_client,
    )
    logger.info("Refresh completed successfully")
if __name__ == "__main__":
    # Script entry point: print INFO-level logs, then run one full refresh
    # using refresh_viewer_data's default arguments.
    logging.basicConfig(level=logging.INFO)
    refresh_job = refresh_viewer_data()
    asyncio.run(refresh_job)