Richard Guo committed
Commit 9a64205 · Parent(s): 1bacad4

working auto dataset upload

Files changed:
- app.py +5 -3
- build_map.py +243 -0
app.py
CHANGED
@@ -1,4 +1,4 @@
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, WebSocket
 from fastapi.responses import HTMLResponse
 from fastapi.templating import Jinja2Templates
 from typing import Optional
@@ -15,15 +15,17 @@ templates = Jinja2Templates(directory="templates")
 # Create a Pydantic model for the form data
 class DatasetForm(BaseModel):
     dataset_name: str
-
 
 
+
+def long_running_function():
+    pass
+
 @app.get("/", response_class=HTMLResponse)
 async def read_form(request: Request):
     # Render the form.html template
     return templates.TemplateResponse("form.html", {"request": request})
 
-
 @app.post("/submit_form")
 async def form_post(form_data: DatasetForm):
     # Do something with form_data
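Note: the new `WebSocket` import and the `long_running_function` stub are not wired up anywhere yet in this commit. A minimal sketch of the likely direction, assuming the upload will later report progress over a WebSocket (the endpoint path and message format here are hypothetical):

    @app.websocket("/progress")
    async def progress(websocket: WebSocket):
        # Hypothetical wiring -- not part of this commit.
        await websocket.accept()
        # Stream status updates while long_running_function() does the upload.
        await websocket.send_json({"status": "started"})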
build_map.py
ADDED
@@ -0,0 +1,243 @@
import time

import nomic
import pandas as pd
import pyarrow as pa
from tqdm import tqdm
from dateutil.parser import parse
from datasets import load_dataset, \
    get_dataset_split_names, \
    get_dataset_config_names, \
    ClassLabel, utils

# Quiet the datasets library's progress/info logging.
utils.logging.set_verbosity_error()


def get_datum_fields(dataset_dict, n_samples=100, unique_cutoff=20):
    # Take a sample of points from the first split.
    dataset = dataset_dict["first_split_dataset"]
    sample = pd.DataFrame(dataset.shuffle(seed=42).take(n_samples))
    features = dataset.features

    numeric_fields = []
    string_fields = []
    bool_fields = []
    list_fields = []
    label_fields = []
    categorical_fields = []
    datetime_fields = []
    uncategorized_fields = []

    # A cutoff below 1 is interpreted as a fraction of the sample size.
    if unique_cutoff < 1:
        unique_cutoff = unique_cutoff * len(sample)

    for field, dtype in dataset_dict["schema"].items():
        try:
            num_unique = sample[field].nunique()
        except Exception:
            # Unhashable values (e.g. lists) -- treat every row as unique.
            num_unique = len(sample)

        if dtype == "string":
            if num_unique < unique_cutoff:
                categorical_fields.append(field)
            else:
                # Treat a string column as datetime only if every sampled
                # value parses strictly.
                is_datetime = True
                for value in sample[field]:
                    try:
                        parse(value, fuzzy=False)
                    except Exception:
                        is_datetime = False
                        break
                if is_datetime:
                    datetime_fields.append(field)
                else:
                    string_fields.append(field)

        # pyarrow reports 16/32/64-bit floats as "halffloat", "float", "double".
        elif dtype in ("halffloat", "float", "double"):
            numeric_fields.append(field)

        elif dtype in ("int64", "int32", "int16", "int8"):
            if features is not None and field in features and isinstance(features[field], ClassLabel):
                label_fields.append(field)
            elif num_unique < unique_cutoff:
                categorical_fields.append(field)
            else:
                numeric_fields.append(field)

        elif dtype == "bool":
            bool_fields.append(field)

        elif dtype.startswith("list"):
            list_fields.append(field)

        else:
            uncategorized_fields.append(field)

    return (features,
            numeric_fields,
            string_fields,
            bool_fields,
            list_fields,
            label_fields,
            categorical_fields,
            datetime_fields,
            uncategorized_fields)


def load_dataset_and_metadata(dataset_name,
                              config=None,
                              streaming=True):
    # Default to the first available config for the dataset.
    configs = get_dataset_config_names(dataset_name)
    if config is None:
        config = configs[0]

    splits = get_dataset_split_names(dataset_name, config)
    dataset = load_dataset(dataset_name, config, split=splits[0], streaming=streaming)
    # _head() is a private helper on streaming datasets that buffers the
    # first few examples; it is used here to recover a concrete pyarrow schema.
    head = pa.Table.from_pydict(dataset._head())

    schema_dict = {field.name: str(field.type) for field in head.schema}

    dataset_dict = {
        "first_split_dataset": dataset,
        "name": dataset_name,
        "config": config,
        "splits": splits,
        "schema": schema_dict,
        "head": head
    }

    return dataset_dict


def upload_project_to_atlas(dataset_dict,
                            project_name=None,
                            unique_id_field_name=None,
                            indexed_field=None,
                            modality=None,
                            organization_name=None):

    if modality is None:
        modality = "text"

    if unique_id_field_name is None:
        unique_id_field_name = "atlas_datum_id"

    if project_name is None:
        project_name = dataset_dict["name"].replace("/", "--")

    desc = f"Config: {dataset_dict['config']}"

    (features,
     numeric_fields,
     string_fields,
     bool_fields,
     list_fields,
     label_fields,
     categorical_fields,
     datetime_fields,
     uncategorized_fields) = get_datum_fields(dataset_dict)

    # Default the indexed field to the longest string field in the first row.
    if indexed_field is None:
        ex = dataset_dict["head"].take([0])
        longest_len = 0
        for field in string_fields:
            value = ex[field][0].as_py()
            if value and len(value) > longest_len:
                indexed_field = field
                longest_len = len(value)

    topic_label_field = None
    if modality == "embedding":
        topic_label_field = indexed_field
        indexed_field = None

    # Fields that can be stringified directly; list fields get their own
    # loop below.
    easy_fields = string_fields + bool_fields + categorical_fields

    proj = nomic.AtlasProject(name=project_name,
                              modality=modality,
                              unique_id_field=unique_id_field_name,
                              organization_name=organization_name,
                              description=desc,
                              reset_project_if_exists=True)

    colorable_fields = ["split"]

    batch_size = 1000
    batched_texts = []

    for split in dataset_dict["splits"]:

        dataset = load_dataset(dataset_dict["name"], dataset_dict["config"], split=split, streaming=True)

        for i, ex in tqdm(enumerate(dataset)):
            # Pause periodically so upload batches can drain.
            if i % 10000 == 0:
                time.sleep(2)

            data_to_add = {"split": split, unique_id_field_name: f"{split}_{i}"}

            for field in numeric_fields:
                data_to_add[field] = ex[field]

            for field in easy_fields:
                val = ""
                if ex[field]:
                    val = str(ex[field])
                data_to_add[field] = val

            for field in datetime_fields:
                try:
                    data_to_add[field] = parse(ex[field], fuzzy=False)
                except Exception:
                    data_to_add[field] = None

            for field in label_fields:
                label_name = ""
                if ex[field] is not None:
                    index = ex[field]
                    # NOTE: THIS MAY BREAK if -1 is ACTUALLY NO LABEL
                    if index != -1:
                        label_name = features[field].names[index]
                data_to_add[field] = str(ex[field])
                data_to_add[field + "_name"] = label_name
                if field + "_name" not in colorable_fields:
                    colorable_fields.append(field + "_name")

            for field in list_fields:
                list_str = ""
                if ex[field]:
                    try:
                        list_str = str(ex[field])
                    except Exception:
                        pass
                data_to_add[field] = list_str

            batched_texts.append(data_to_add)

            if len(batched_texts) >= batch_size:
                proj.add_text(batched_texts)
                batched_texts = []

    # Flush any remaining partial batch.
    if len(batched_texts) > 0:
        proj.add_text(batched_texts)

    colorable_fields = colorable_fields + \
        categorical_fields + label_fields + bool_fields + datetime_fields

    projection = proj.create_index(name=project_name + " index",
                                   indexed_field=indexed_field,
                                   colorable_fields=colorable_fields,
                                   topic_label_field=topic_label_field,
                                   build_topic_model=True)

    return projection.map_link

+
# Run test
|
237 |
+
if __name__ == "__main__":
|
238 |
+
dataset_name = "databricks/databricks-dolly-15k"
|
239 |
+
#dataset_name = "fka/awesome-chatgpt-prompts"
|
240 |
+
project_name = "huggingface_auto_upload_test-dolly-15k"
|
241 |
+
|
242 |
+
dataset_dict = load_dataset_and_metadata(dataset_name)
|
243 |
+
print(upload_project_to_atlas(dataset_dict, project_name=project_name))
|
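For context, the commit message is "working auto dataset upload", and app.py's form handler is still a stub, so presumably the handler will eventually call into this module. A rough sketch of that glue, assuming build_map.py is importable from the app (the two function names match this file; the wiring itself, and the `app` and `DatasetForm` names, come from app.py above):

    # Hypothetical glue in app.py -- not part of this commit.
    from build_map import load_dataset_and_metadata, upload_project_to_atlas

    @app.post("/submit_form")
    async def form_post(form_data: DatasetForm):
        dataset_dict = load_dataset_and_metadata(form_data.dataset_name)
        map_link = upload_project_to_atlas(dataset_dict)
        return {"map_link": map_link}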