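"""Refresh the ChromaDB index of dataset viewer descriptions.

Prepares viewer data, embeds each dataset description via a Hugging Face
Inference Endpoint, and upserts the resulting vectors into a persistent
ChromaDB collection.
"""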
import asyncio
import logging

import chromadb
import requests
import stamina
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import InferenceClient
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map

from prep_viewer_data import prep_data

# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def initialize_clients():
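    """Return a persistent ChromaDB client and an InferenceClient for the dedicated embedding endpoint."""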
    logger.info("Initializing clients")
    chroma_client = chromadb.PersistentClient()
    inference_client = InferenceClient(
        "https://bm143rfir2on1bkw.us-east-1.aws.endpoints.huggingface.cloud"
    )
    return chroma_client, inference_client


def create_collection(chroma_client):
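    """Get or create the ``dataset-viewer-descriptions`` collection.

    Uses the ``davanstrien/dataset-viewer-descriptions-processed-st``
    SentenceTransformer as the embedding function and cosine distance for HNSW.
    """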
    logger.info("Creating or getting collection")
    embedding_function = SentenceTransformerEmbeddingFunction(
        model_name="davanstrien/dataset-viewer-descriptions-processed-st",
        trust_remote_code=True,
    )
    return chroma_client.create_collection(
        name="dataset-viewer-descriptions",
        get_or_create=True,
        embedding_function=embedding_function,
        metadata={"hnsw:space": "cosine"},
    )


@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text, client):
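    """Embed a single document via the Inference Endpoint, retrying up to 3 times on HTTP errors.

    The text is truncated to 8,192 characters (not tokens) before embedding.
    """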
    text = text[:8192]
    return client.feature_extraction(text)


def embed_and_upsert_datasets(
    dataset_rows_and_ids, collection, inference_client, batch_size=10
):
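    """Embed dataset descriptions in batches and upsert the vectors into the collection.

    Each document is prefixed with ``HUB_DATASET_PREVIEW:`` and embedded
    concurrently via ``thread_map``; only ids and embeddings are stored.
    """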
    logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
    for i in tqdm(range(0, len(dataset_rows_and_ids), batch_size)):
        batch = dataset_rows_and_ids[i : i + batch_size]
        ids = []
        documents = []
        for item in batch:
            ids.append(item["dataset_id"])
            documents.append(f"HUB_DATASET_PREVIEW: {item['formatted_prompt']}")
        results = thread_map(
            lambda doc: embed_card(doc, inference_client), documents, leave=False
        )
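        # Each embedding is assumed to come back as a (1, dim) array; keep the first row.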
        collection.upsert(
            ids=ids,
            embeddings=[embedding.tolist()[0] for embedding in results],
        )
        logger.debug(f"Processed batch {i//batch_size + 1}")


async def refresh_viewer_data(sample_size=100_000, min_likes=2):
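    """Prepare viewer data, embed it, and upsert it into the ChromaDB collection."""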
    logger.info(
        f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
    )
    chroma_client, inference_client = initialize_clients()
    collection = create_collection(chroma_client)

    logger.info("Preparing data")
    df = await prep_data(sample_size=sample_size, min_likes=min_likes)
    dataset_rows_and_ids = df.to_dicts()

    logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
    embed_and_upsert_datasets(dataset_rows_and_ids, collection, inference_client)
    logger.info("Refresh completed successfully")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(refresh_viewer_data())