|
|
|
import os, asyncio |
|
from huggingface_hub import InferenceClient |
|
from sklearn.cluster import KMeans |
|
|
|
|
|
# Sentence-embedding model requested by embed_texts.
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"

# Hugging Face API token taken from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Shared Inference API client, authenticated with HF_TOKEN.
client = InferenceClient(token=HF_TOKEN)
|
|
|
async def embed_texts(texts: list[str]) -> list[list[float]]:
    """
    Compute embeddings for a list of texts via the HF Inference API.

    Each text is sent as its own ``feature_extraction`` request. The client
    call is blocking, so every request runs through ``asyncio.to_thread``;
    the requests overlap instead of running one after another on the
    event-loop thread.

    Args:
        texts: Strings to embed.

    Returns:
        One embedding (a flat list of floats) per input text, in the same
        order as ``texts``. An empty ``texts`` yields ``[]``.
    """

    def _embed(text: str) -> list[float]:
        # InferenceClient exposes embeddings through ``feature_extraction``
        # (there is no ``embed`` method); it returns a numpy array, which is
        # converted to plain Python floats to match the declared return type.
        vector = client.feature_extraction(text, model=EMBED_MODEL).tolist()
        # Some backends answer with a batch of one, i.e. shape (1, dim);
        # unwrap it so each entry is a flat vector.
        # NOTE(review): assumes the model returns a pooled sentence embedding,
        # not per-token vectors — confirm for models other than EMBED_MODEL.
        if len(vector) == 1 and isinstance(vector[0], list):
            vector = vector[0]
        return vector

    # gather returns results in the order of its awaitables, so the output
    # lines up with ``texts``.
    return await asyncio.gather(*(asyncio.to_thread(_embed, t) for t in texts))
|
|
|
async def cluster_embeddings(embs: list[list[float]], n_clusters: int = 5) -> list[int]:
    """
    Cluster embeddings with k-means and return one cluster label per embedding.

    The synchronous scikit-learn fit is handed to a worker thread via
    ``asyncio.to_thread`` rather than being run directly on the event-loop
    thread.

    Args:
        embs: Embedding vectors; KMeans requires them to share one length.
        n_clusters: Desired number of clusters. When there are fewer
            embeddings than this, it is reduced to ``len(embs)``, because
            k-means cannot form more clusters than it has samples.

    Returns:
        Integer cluster labels aligned with ``embs``; ``[]`` for empty input.
    """
    # len() rather than truthiness so a numpy array input also works.
    n_samples = len(embs)
    if n_samples == 0:
        return []

    kmeans = KMeans(n_clusters=min(n_clusters, n_samples), random_state=0)
    labels = await asyncio.to_thread(kmeans.fit_predict, embs)
    return labels.tolist()
|
|