davanstrien's picture
davanstrien HF staff
add revision
a2828aa
raw
history blame
5.99 kB
import logging
import os
import platform
from datetime import datetime
from typing import List, Literal, Optional, Tuple
import chromadb
import polars as pl
import requests
import stamina
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from tqdm.contrib.concurrent import thread_map
# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Load variables from a .env file (if any) before reading HF_TOKEN below.
load_dotenv()
# Token handed to the InferenceClient; None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
# Model name and pinned revision used to build the collection's embedding function.
EMBEDDING_MODEL_NAME = "Snowflake/snowflake-arctic-embed-m-long"
EMBEDDING_MODEL_REVISION = "ac9d0cb43661ee1f7d67b3aa63614d65a6c86463"
# Endpoint URL passed as `model` to the InferenceClient that embeds the cards.
INFERENCE_MODEL_URL = (
    "https://pqzap00ebpl1ydt4.us-east-1.aws.endpoints.huggingface.cloud"
)
# Parquet file read by load_cards().
DATASET_PARQUET_URL = "hf://datasets/librarian-bots/dataset_cards_with_metadata_with_embeddings/data/train-00000-of-00001.parquet"
# Name of the Chroma collection that is created or reopened.
COLLECTION_NAME = "dataset_cards"
# Cards are sliced to this many characters before being sent for embedding.
MAX_EMBEDDING_LENGTH = 8192
def get_save_path() -> Literal["chroma/"] | Literal["/data/chroma/"]:
    """Pick the directory used for the persistent Chroma store.

    Returns:
        ``"chroma/"`` when ``platform.system()`` reports ``"Darwin"``,
        otherwise ``"/data/chroma/"``. The chosen path is logged.
    """
    if platform.system() == "Darwin":
        path = "chroma/"
    else:
        path = "/data/chroma/"
    logger.info(f"Using save path: {path}")
    return path
# Resolved once at import time; get_chroma_client() opens the store at this path.
SAVE_PATH = get_save_path()
def get_chroma_client():
    """Create a persistent Chroma client rooted at ``SAVE_PATH``."""
    logger.info("Initializing Chroma client")
    client = chromadb.PersistentClient(path=SAVE_PATH)
    return client
def get_embedding_function():
    """Build the sentence-transformer embedding function for the collection.

    The model name and revision come from ``EMBEDDING_MODEL_NAME`` and
    ``EMBEDDING_MODEL_REVISION``; ``trust_remote_code`` is enabled.
    """
    logger.info(f"Initializing embedding function with model: {EMBEDDING_MODEL_NAME}")
    options = {
        "model_name": EMBEDDING_MODEL_NAME,
        "trust_remote_code": True,
        "revision": EMBEDDING_MODEL_REVISION,
    }
    return embedding_functions.SentenceTransformerEmbeddingFunction(**options)
def get_collection(chroma_client, embedding_function):
    """Return the ``COLLECTION_NAME`` collection, creating it if necessary.

    Args:
        chroma_client: Client on which ``create_collection`` is called.
        embedding_function: Embedding function attached to the collection.
    """
    logger.info(f"Getting or creating collection: {COLLECTION_NAME}")
    collection = chroma_client.create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedding_function,
        # Reuse an existing collection of this name instead of failing.
        get_or_create=True,
    )
    return collection
def get_last_modified_in_collection(collection) -> datetime | None:
    """Find the newest ``last_modified`` timestamp stored in the collection.

    Every item's metadata is fetched and its ``"last_modified"`` entry is
    parsed with ``datetime.fromisoformat``.

    Returns:
        The maximum parsed timestamp, or ``None`` when the collection holds
        no items.
    """
    logger.info("Fetching last modified date from collection")
    metadatas = collection.get(include=["metadatas"])["metadatas"]
    timestamps = [
        datetime.fromisoformat(meta["last_modified"]) for meta in metadatas
    ]
    if not timestamps:
        logger.info("No last modified date found")
        return None
    newest = max(timestamps)
    logger.info(f"Last modified date: {newest}")
    return newest
def parse_markdown_column(
    df: pl.DataFrame, markdown_column: str, dataset_id_column: str
) -> pl.DataFrame:
    """Add cleaned and ID-prefixed versions of a markdown column.

    Two columns are added to ``df``:

    * ``parsed_markdown`` - the text following a leading ``--- ... ---``
      block, or the original text when the pattern does not match, with
      surrounding whitespace stripped.
    * ``prepended_markdown`` - ``"Dataset ID "``, the dataset id, a blank
      line, then the same cleaned text.

    Args:
        df: Frame containing both named columns.
        markdown_column: Name of the column holding the raw markdown.
        dataset_id_column: Name of the column holding the dataset id.
    """
    logger.info("Parsing markdown column")
    raw_card = pl.col(markdown_column)
    # Capture everything after the leading "--- ... ---" block; fall back to
    # the untouched text when there is no such block.
    card_body = (
        raw_card.str.extract(r"(?s)^---.*?---\s*(.*)", group_index=1)
        .fill_null(raw_card)
        .str.strip_chars()
    )
    card_with_id = pl.concat_str(
        [
            pl.lit("Dataset ID "),
            pl.col(dataset_id_column).cast(pl.Utf8),
            pl.lit("\n\n"),
            card_body,
        ]
    )
    return df.with_columns(
        parsed_markdown=card_body,
        prepended_markdown=card_with_id,
    )
def load_cards(
    min_len: int = 50,
    min_likes: int | None = None,
    last_modified: Optional[datetime] = None,
) -> Optional[Tuple[List[str], List[str], List[datetime]]]:
    """Load dataset cards from the parquet file and apply the filters.

    Args:
        min_len: Keep only cards whose parsed markdown is strictly longer
            than this many characters.
        min_likes: When not ``None``, keep only datasets with strictly more
            likes than this value. ``0`` is a valid threshold.
        last_modified: When not ``None``, keep only rows modified strictly
            after this timestamp.

    Returns:
        ``(cards, dataset_ids, last_modifieds)`` as parallel lists, where
        ``cards`` holds the ID-prefixed markdown, or ``None`` when no row
        survives the filters.
    """
    logger.info(
        f"Loading cards with min_len={min_len}, min_likes={min_likes}, last_modified={last_modified}"
    )
    df = pl.read_parquet(DATASET_PARQUET_URL)
    df = parse_markdown_column(df, "card", "datasetId")
    df = df.with_columns(pl.col("parsed_markdown").str.len_chars().alias("card_len"))
    df = df.filter(pl.col("card_len") > min_len)
    # Compare against None explicitly: a plain truthiness test would treat
    # min_likes=0 as "no filter" and silently skip the likes > 0 filter.
    if min_likes is not None:
        df = df.filter(pl.col("likes") > min_likes)
    if last_modified is not None:
        df = df.filter(pl.col("last_modified") > last_modified)
    if df.is_empty():
        logger.info("No cards found matching criteria")
        return None
    cards = df.get_column("prepended_markdown").to_list()
    dataset_ids = df.get_column("datasetId").to_list()
    last_modifieds = df.get_column("last_modified").to_list()
    logger.info(f"Loaded {len(cards)} cards")
    return cards, dataset_ids, last_modifieds
@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text, client):
    """Embed one card via the client's ``feature_extraction`` call.

    The text is cut to its first ``MAX_EMBEDDING_LENGTH`` characters before
    being sent. The decorator retries on ``requests.HTTPError`` for up to
    three attempts.

    Args:
        text: Card text to embed.
        client: Object exposing ``feature_extraction(text)``.
    """
    truncated = text[:MAX_EMBEDDING_LENGTH]
    embedding = client.feature_extraction(truncated)
    return embedding
def get_inference_client():
    """Create the ``InferenceClient`` pointed at ``INFERENCE_MODEL_URL``.

    The client authenticates with ``HF_TOKEN``.
    """
    logger.info(f"Initializing inference client with model: {INFERENCE_MODEL_URL}")
    client = InferenceClient(model=INFERENCE_MODEL_URL, token=HF_TOKEN)
    return client
def refresh_data(min_len: int = 200, min_likes: Optional[int] = None):
    """Embed and store cards newer than anything already in the collection.

    The newest ``last_modified`` value in the collection is used as the
    lower bound when loading cards, so only later rows are embedded and
    upserted.

    Args:
        min_len: Minimum parsed-card length forwarded to ``load_cards``.
        min_likes: Optional likes threshold forwarded to ``load_cards``.
    """
    logger.info(f"Starting data refresh with min_len={min_len}, min_likes={min_likes}")
    collection = get_collection(get_chroma_client(), get_embedding_function())
    newest_stored = get_last_modified_in_collection(collection)
    data = load_cards(
        min_len=min_len, min_likes=min_likes, last_modified=newest_stored
    )
    if not data:
        logger.info("No new data to refresh")
        return
    _create_and_upsert_embeddings(data, collection)
def _create_and_upsert_embeddings(data, collection):
    """Embed every card and upsert the vectors into the collection.

    Args:
        data: ``(cards, ids, last_modifieds)`` as returned by ``load_cards``.
        collection: Collection whose ``upsert`` receives the ids, embeddings
            and ``last_modified`` metadata.
    """
    cards, model_ids, last_modifieds = data
    logger.info("Embedding cards...")
    inference_client = get_inference_client()

    def _embed(card):
        # Bind the shared client so thread_map only has to supply the card.
        return embed_card(card, inference_client)

    results = thread_map(_embed, cards)
    logger.info(f"Upserting {len(model_ids)} items to collection")
    # Only the first row of each result is kept as the card's vector.
    vectors = [result.tolist()[0] for result in results]
    metadatas = [{"last_modified": str(stamp)} for stamp in last_modifieds]
    # NOTE(review): everything goes in a single upsert call; confirm the
    # client's maximum batch size is not exceeded on a large refresh.
    collection.upsert(
        ids=model_ids,
        embeddings=vectors,
        metadatas=metadatas,
    )
    logger.info("Data refresh completed successfully")
# Run a refresh with the default thresholds when executed as a script.
if __name__ == "__main__":
    refresh_data()