Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
3e2784f
1
Parent(s):
551f450
load viewer data
Browse files- load_viewer_data.py +88 -0
- prep_viewer_data.py +158 -0
load_viewer_data.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import chromadb
|
5 |
+
import httpx
|
6 |
+
import requests
|
7 |
+
import stamina
|
8 |
+
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
|
9 |
+
from huggingface_hub import InferenceClient
|
10 |
+
from tqdm.auto import tqdm
|
11 |
+
from tqdm.contrib.concurrent import thread_map
|
12 |
+
|
13 |
+
from prep_viewer_data import prep_data
|
14 |
+
|
15 |
+
# Set up logging
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
logger.setLevel(logging.INFO)
|
18 |
+
|
19 |
+
|
20 |
+
def initialize_clients():
|
21 |
+
logger.info("Initializing clients")
|
22 |
+
chroma_client = chromadb.PersistentClient()
|
23 |
+
inference_client = InferenceClient(
|
24 |
+
"https://bm143rfir2on1bkw.us-east-1.aws.endpoints.huggingface.cloud"
|
25 |
+
)
|
26 |
+
return chroma_client, inference_client
|
27 |
+
|
28 |
+
|
29 |
+
def create_collection(chroma_client):
|
30 |
+
logger.info("Creating or getting collection")
|
31 |
+
embedding_function = SentenceTransformerEmbeddingFunction(
|
32 |
+
model_name="davanstrien/dataset-viewer-descriptions-processed-st",
|
33 |
+
trust_remote_code=True,
|
34 |
+
)
|
35 |
+
return chroma_client.create_collection(
|
36 |
+
name="dataset-viewer-descriptions",
|
37 |
+
get_or_create=True,
|
38 |
+
embedding_function=embedding_function,
|
39 |
+
metadata={"hnsw:space": "cosine"},
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
|
44 |
+
def embed_card(text, client):
|
45 |
+
text = text[:8192]
|
46 |
+
return client.feature_extraction(text)
|
47 |
+
|
48 |
+
|
49 |
+
def embed_and_upsert_datasets(
|
50 |
+
dataset_rows_and_ids, collection, inference_client, batch_size=10
|
51 |
+
):
|
52 |
+
logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
|
53 |
+
for i in tqdm(range(0, len(dataset_rows_and_ids), batch_size)):
|
54 |
+
batch = dataset_rows_and_ids[i : i + batch_size]
|
55 |
+
ids = []
|
56 |
+
documents = []
|
57 |
+
for item in batch:
|
58 |
+
ids.append(item["dataset_id"])
|
59 |
+
documents.append(f"HUB_DATASET_PREVIEW: {item['formatted_prompt']}")
|
60 |
+
results = thread_map(
|
61 |
+
lambda doc: embed_card(doc, inference_client), documents, leave=False
|
62 |
+
)
|
63 |
+
collection.upsert(
|
64 |
+
ids=ids,
|
65 |
+
embeddings=[embedding.tolist()[0] for embedding in results],
|
66 |
+
)
|
67 |
+
logger.debug(f"Processed batch {i//batch_size + 1}")
|
68 |
+
|
69 |
+
|
70 |
+
async def refresh_viewer_data(sample_size=100_000, min_likes=2):
|
71 |
+
logger.info(
|
72 |
+
f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
|
73 |
+
)
|
74 |
+
chroma_client, inference_client = initialize_clients()
|
75 |
+
collection = create_collection(chroma_client)
|
76 |
+
|
77 |
+
logger.info("Preparing data")
|
78 |
+
df = await prep_data(sample_size=sample_size, min_likes=min_likes)
|
79 |
+
dataset_rows_and_ids = df.to_dicts()
|
80 |
+
|
81 |
+
logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
|
82 |
+
embed_and_upsert_datasets(dataset_rows_and_ids, collection, inference_client)
|
83 |
+
logger.info("Refresh completed successfully")
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == "__main__":
|
87 |
+
logging.basicConfig(level=logging.INFO)
|
88 |
+
asyncio.run(refresh_viewer_data())
|
prep_viewer_data.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
|
5 |
+
import httpx
|
6 |
+
import polars as pl
|
7 |
+
from huggingface_hub import list_datasets
|
8 |
+
from tqdm import tqdm
|
9 |
+
from tqdm.asyncio import tqdm_asyncio
|
10 |
+
|
11 |
+
# Initialize the HTTP client
|
12 |
+
client = httpx.AsyncClient(timeout=60, http2=True)
|
13 |
+
|
14 |
+
|
15 |
+
async def generate_dataset_prompt(dataset_name, num_rows=2):
|
16 |
+
try:
|
17 |
+
base_url = "https://datasets-server.huggingface.co"
|
18 |
+
|
19 |
+
# Get splits and configs
|
20 |
+
splits_url = f"{base_url}/splits?dataset={dataset_name}"
|
21 |
+
splits_response = await client.get(splits_url)
|
22 |
+
splits_data = splits_response.json()
|
23 |
+
|
24 |
+
if not splits_data.get("splits"):
|
25 |
+
return None
|
26 |
+
|
27 |
+
# Get the first config and split
|
28 |
+
first_split = splits_data["splits"][0]
|
29 |
+
config_name = first_split["config"]
|
30 |
+
split_name = first_split["split"]
|
31 |
+
|
32 |
+
# Get dataset info for the specific config
|
33 |
+
info_url = f"{base_url}/info?dataset={dataset_name}&config={config_name}"
|
34 |
+
info_response = await client.get(info_url)
|
35 |
+
info_data = info_response.json()
|
36 |
+
|
37 |
+
# Get first rows for the specific config and split
|
38 |
+
first_rows_url = f"{base_url}/first-rows?dataset={dataset_name}&config={config_name}&split={split_name}"
|
39 |
+
first_rows_response = await client.get(first_rows_url)
|
40 |
+
first_rows_data = first_rows_response.json()
|
41 |
+
|
42 |
+
# Get size information
|
43 |
+
size_url = f"{base_url}/size?dataset={dataset_name}"
|
44 |
+
size_response = await client.get(size_url)
|
45 |
+
size_data = size_response.json()
|
46 |
+
|
47 |
+
# Extract relevant information
|
48 |
+
dataset_info = info_data.get("dataset_info", {})
|
49 |
+
features = dataset_info.get("features", {})
|
50 |
+
splits = dataset_info.get("splits", {})
|
51 |
+
|
52 |
+
# Calculate total examples and size
|
53 |
+
total_examples = sum(split.get("num_examples", 0) for split in splits.values())
|
54 |
+
total_size = (
|
55 |
+
size_data.get("size", {})
|
56 |
+
.get("dataset", {})
|
57 |
+
.get("num_bytes_original_files", 0)
|
58 |
+
)
|
59 |
+
|
60 |
+
# Format features
|
61 |
+
def format_feature(name, details):
|
62 |
+
if isinstance(details, dict):
|
63 |
+
feature_type = details.get(
|
64 |
+
"dtype", details.get("_type", "unknown type")
|
65 |
+
)
|
66 |
+
elif isinstance(details, list):
|
67 |
+
feature_type = "list"
|
68 |
+
else:
|
69 |
+
feature_type = str(type(details).__name__)
|
70 |
+
return f"- {name} ({feature_type})"
|
71 |
+
|
72 |
+
formatted_features = "\n".join(
|
73 |
+
format_feature(name, details) for name, details in features.items()
|
74 |
+
)
|
75 |
+
|
76 |
+
# Format sample data (specified number of rows)
|
77 |
+
sample_data = json.dumps(first_rows_data.get("rows", [])[:num_rows], indent=2)
|
78 |
+
|
79 |
+
# Create the formatted prompt
|
80 |
+
prompt = f"""
|
81 |
+
Dataset: "{dataset_name}"
|
82 |
+
|
83 |
+
Features:
|
84 |
+
{formatted_features}
|
85 |
+
|
86 |
+
Splits and Configs:
|
87 |
+
{', '.join(f"{split['config']}/{split['split']}" for split in splits_data['splits'])}
|
88 |
+
|
89 |
+
Size Statistics:
|
90 |
+
Total Examples: {total_examples}
|
91 |
+
Split Sizes: {', '.join(f"{split}: {info['num_examples']}" for split, info in splits.items())}
|
92 |
+
|
93 |
+
Data Sample ({num_rows} rows out of {total_examples} total):
|
94 |
+
{sample_data}
|
95 |
+
"""
|
96 |
+
|
97 |
+
return prompt.strip()
|
98 |
+
except Exception as e:
|
99 |
+
print(f"Error for {dataset_name}: {e}")
|
100 |
+
return None
|
101 |
+
|
102 |
+
|
103 |
+
async def process_batch(batch):
|
104 |
+
results = await tqdm_asyncio.gather(
|
105 |
+
*[generate_dataset_prompt(dataset) for dataset in batch], leave=False
|
106 |
+
)
|
107 |
+
return [
|
108 |
+
(dataset_id, prompt)
|
109 |
+
for dataset_id, prompt in zip(batch, results)
|
110 |
+
if prompt is not None
|
111 |
+
]
|
112 |
+
|
113 |
+
|
114 |
+
async def prep_data(sample_size=200_000, min_likes=1):
|
115 |
+
# Load the dataset containing dataset IDs
|
116 |
+
df = pl.read_parquet(
|
117 |
+
"hf://datasets/davanstrien/dataset-viewer-descriptions-processed/data/train-00000-of-00001.parquet"
|
118 |
+
)
|
119 |
+
in_train_or_test = set(df["dataset_id"].unique().to_list())
|
120 |
+
|
121 |
+
# Get all datasets
|
122 |
+
datasets = [
|
123 |
+
dataset for dataset in list_datasets() if dataset.id not in in_train_or_test
|
124 |
+
]
|
125 |
+
# filter to datasets with 1 or more likes
|
126 |
+
if min_likes:
|
127 |
+
datasets = [dataset for dataset in datasets if dataset.likes >= min_likes]
|
128 |
+
datasets = [dataset.id for dataset in datasets]
|
129 |
+
# Sample datasets (adjust the number as needed)
|
130 |
+
datasets = random.sample(datasets, min(sample_size, len(datasets)))
|
131 |
+
|
132 |
+
# Process datasets in batches of 100
|
133 |
+
batch_size = 500
|
134 |
+
all_results = []
|
135 |
+
|
136 |
+
for i in tqdm(range(0, len(datasets), batch_size), desc="Processing batches"):
|
137 |
+
batch = datasets[i : i + batch_size]
|
138 |
+
batch_results = await process_batch(batch)
|
139 |
+
all_results.extend(batch_results)
|
140 |
+
|
141 |
+
# Optional: Save intermediate results
|
142 |
+
if len(all_results) % 1000 == 0:
|
143 |
+
intermediate_df = pl.DataFrame(
|
144 |
+
{
|
145 |
+
"dataset_id": [row[0] for row in all_results],
|
146 |
+
"formatted_prompt": [row[1] for row in all_results],
|
147 |
+
}
|
148 |
+
)
|
149 |
+
intermediate_df.write_parquet(
|
150 |
+
f"dataset_prompts_intermediate_{len(all_results)}.parquet"
|
151 |
+
)
|
152 |
+
|
153 |
+
return pl.DataFrame(
|
154 |
+
{
|
155 |
+
"dataset_id": [row[0] for row in all_results],
|
156 |
+
"formatted_prompt": [row[1] for row in all_results],
|
157 |
+
}
|
158 |
+
)
|