Richard Guo committed on
Commit 9a64205 · 1 Parent(s): 1bacad4

working auto dataset upload

Files changed (2)
  1. app.py +5 -3
  2. build_map.py +243 -0
app.py CHANGED
@@ -1,4 +1,4 @@
- from fastapi import FastAPI, Request
+ from fastapi import FastAPI, Request, WebSocket
  from fastapi.responses import HTMLResponse
  from fastapi.templating import Jinja2Templates
  from typing import Optional
@@ -15,15 +15,17 @@ templates = Jinja2Templates(directory="templates")
  # Create a Pydantic model for the form data
  class DatasetForm(BaseModel):
      dataset_name: str
-


+
+ def long_running_function():
+     pass
+
  @app.get("/", response_class=HTMLResponse)
  async def read_form(request: Request):
      # Render the form.html template
      return templates.TemplateResponse("form.html", {"request": request})

-
  @app.post("/submit_form")
  async def form_post(form_data: DatasetForm):
      # Do something with form_data
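
For context: the WebSocket import and the long_running_function stub suggest the form submission will eventually kick off the dataset upload outside the request cycle. A minimal sketch of one possible wiring, assuming FastAPI's BackgroundTasks; run_upload and this handler body are hypothetical, not part of this commit:

from fastapi import BackgroundTasks
from build_map import load_dataset_and_metadata, upload_project_to_atlas

def run_upload(dataset_name: str):
    # Hypothetical helper: build the metadata dict, then push it to Atlas
    dataset_dict = load_dataset_and_metadata(dataset_name)
    return upload_project_to_atlas(dataset_dict)

@app.post("/submit_form")
async def form_post(form_data: DatasetForm, background_tasks: BackgroundTasks):
    # Queue the long-running upload so the response returns immediately
    background_tasks.add_task(run_upload, form_data.dataset_name)
    return {"status": "upload started", "dataset": form_data.dataset_name}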
build_map.py ADDED
@@ -0,0 +1,243 @@
import nomic
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset, \
    get_dataset_split_names, \
    get_dataset_config_names, \
    ClassLabel, utils

utils.logging.set_verbosity_error()
import pyarrow as pa
from dateutil.parser import parse
import time

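# Heuristically bucket every schema field by its pyarrow dtype string:
# low-cardinality strings become categorical, fully parseable strings become
# datetimes, ClassLabel ints become labels, and the rest fall through.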
def get_datum_fields(dataset_dict, n_samples=100, unique_cutoff=20):
    # take a sample of points
    dataset = dataset_dict["first_split_dataset"]
    sample = pd.DataFrame(dataset.shuffle(seed=42).take(n_samples))
    features = dataset.features

    numeric_fields = []
    string_fields = []
    bool_fields = []
    list_fields = []
    label_fields = []
    categorical_fields = []
    datetime_fields = []
    uncategorized_fields = []

    # a cutoff below 1 is treated as a fraction of the sample size
    if unique_cutoff < 1:
        unique_cutoff = unique_cutoff * len(sample)

    for field, dtype in dataset_dict["schema"].items():
        try:
            num_unique = sample[field].nunique()
        except Exception:
            num_unique = len(sample)

        if dtype == "string":
            if num_unique < unique_cutoff:
                categorical_fields.append(field)
            else:
                # a string field counts as datetime only if every sampled value parses
                is_datetime = True
                for value in sample[field]:
                    try:
                        parse(value, fuzzy=False)
                    except Exception:
                        is_datetime = False
                        break
                if is_datetime:
                    datetime_fields.append(field)
                else:
                    string_fields.append(field)

        elif dtype in ("float", "double", "halffloat"):
            numeric_fields.append(field)

        elif dtype in ("int64", "int32", "int16", "int8"):
            if features is not None and field in features and isinstance(features[field], ClassLabel):
                label_fields.append(field)
            elif num_unique < unique_cutoff:
                categorical_fields.append(field)
            else:
                numeric_fields.append(field)

        elif dtype == "bool":
            bool_fields.append(field)

        elif dtype.startswith("list"):
            list_fields.append(field)

        else:
            uncategorized_fields.append(field)

    return features, \
        numeric_fields, \
        string_fields, \
        bool_fields, \
        list_fields, \
        label_fields, \
        categorical_fields, \
        datetime_fields, \
        uncategorized_fields

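# Resolve the dataset's configs and splits, stream the first split, and
# materialize a small pyarrow "head" table so the schema can be read cheaply.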
def load_dataset_and_metadata(dataset_name,
                              config=None,
                              streaming=True):

    configs = get_dataset_config_names(dataset_name)
    if config is None:
        config = configs[0]

    splits = get_dataset_split_names(dataset_name, config)
    dataset = load_dataset(dataset_name, config, split=splits[0], streaming=streaming)
    # _head() is a private IterableDataset helper that materializes the first few rows
    head = pa.Table.from_pydict(dataset._head())

    schema_dict = {field.name: str(field.type) for field in head.schema}

    dataset_dict = {
        "first_split_dataset": dataset,
        "name": dataset_name,
        "config": config,
        "splits": splits,
        "schema": schema_dict,
        "head": head
    }

    return dataset_dict

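# Stream every split of the dataset into a Nomic Atlas project in batches of
# 1000 rows, then build a map index on the chosen (or longest) string field.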
def upload_project_to_atlas(dataset_dict,
                            project_name=None,
                            unique_id_field_name=None,
                            indexed_field=None,
                            modality=None,
                            organization_name=None):

    if modality is None:
        modality = "text"

    if unique_id_field_name is None:
        unique_id_field_name = "atlas_datum_id"

    if project_name is None:
        project_name = dataset_dict["name"].replace("/", "--")

    desc = f"Config: {dataset_dict['config']}"

    features, \
        numeric_fields, \
        string_fields, \
        bool_fields, \
        list_fields, \
        label_fields, \
        categorical_fields, \
        datetime_fields, \
        uncategorized_fields = get_datum_fields(dataset_dict)

    # default to the longest string field in the first row as the indexed field
    if indexed_field is None:
        first_row = dataset_dict["head"].take([0])
        longest_len = 0
        for field in string_fields:
            value = first_row[field][0].as_py()
            if value and len(value) > longest_len:
                indexed_field = field
                longest_len = len(value)

    topic_label_field = None
    if modality == "embedding":
        topic_label_field = indexed_field
        indexed_field = None

    easy_fields = string_fields + bool_fields + list_fields + categorical_fields

    proj = nomic.AtlasProject(name=project_name,
                              modality=modality,
                              unique_id_field=unique_id_field_name,
                              organization_name=organization_name,
                              description=desc,
                              reset_project_if_exists=True)

    colorable_fields = ["split"]

    batch_size = 1000
    batched_texts = []

    for split in dataset_dict["splits"]:

        dataset = load_dataset(dataset_dict["name"], dataset_dict["config"], split=split, streaming=True)

        for i, ex in tqdm(enumerate(dataset)):
            # back off briefly every 10k rows to avoid hammering the API
            if i % 10000 == 0:
                time.sleep(2)

            data_to_add = {"split": split, unique_id_field_name: f"{split}_{i}"}

            for field in numeric_fields:
                data_to_add[field] = ex[field]

            for field in easy_fields:
                val = ""
                if ex[field]:
                    val = str(ex[field])
                data_to_add[field] = val

            for field in datetime_fields:
                try:
                    data_to_add[field] = parse(ex[field], fuzzy=False)
                except Exception:
                    data_to_add[field] = None

            for field in label_fields:
                label_name = ""
                if ex[field] is not None:
                    index = ex[field]
                    # NOTE: THIS MAY BREAK if -1 is ACTUALLY NO LABEL
                    if index != -1:
                        label_name = features[field].names[index]
                data_to_add[field] = str(ex[field])
                data_to_add[field + "_name"] = label_name
                if field + "_name" not in colorable_fields:
                    colorable_fields.append(field + "_name")

            for field in list_fields:
                list_str = ""
                if ex[field]:
                    try:
                        list_str = str(ex[field])
                    except Exception:
                        continue
                data_to_add[field] = list_str

            batched_texts.append(data_to_add)

            if len(batched_texts) >= batch_size:
                proj.add_text(batched_texts)
                batched_texts = []

    if len(batched_texts) > 0:
        proj.add_text(batched_texts)

    colorable_fields = colorable_fields + \
        categorical_fields + label_fields + bool_fields + datetime_fields

    projection = proj.create_index(name=project_name + " index",
                                   indexed_field=indexed_field,
                                   colorable_fields=colorable_fields,
                                   topic_label_field=topic_label_field,
                                   build_topic_model=True)

    return projection.map_link

# Run test
if __name__ == "__main__":
    dataset_name = "databricks/databricks-dolly-15k"
    #dataset_name = "fka/awesome-chatgpt-prompts"
    project_name = "huggingface_auto_upload_test-dolly-15k"

    dataset_dict = load_dataset_and_metadata(dataset_name)
    print(upload_project_to_atlas(dataset_dict, project_name=project_name))
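
Usage note: the __main__ block above is the whole test path; any public Hugging Face dataset name should work the same way, assuming the nomic client is already authenticated with Atlas. A sketch with an explicit config (the dataset and config here are illustrative, not taken from this commit):

dataset_dict = load_dataset_and_metadata("glue", config="sst2")
print(upload_project_to_atlas(dataset_dict, project_name="glue-sst2-map"))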