maximevo committed on
Commit 86971d9 · 1 Parent(s): 838b068

Upload app.py

Files changed (1)
  1. app.py +498 -0
app.py ADDED
@@ -0,0 +1,498 @@
+ import streamlit as st
+ import requests
+ import re
+ import json
+ import time
+ import pandas as pd
+ import labelbox
+
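+ # Assumed dependencies: pip install streamlit requests pandas labelbox
+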
+ def validate_dataset_name(name):
+     """Validate the dataset name."""
+     # Check length
+     if len(name) > 256:
+         return "Dataset name should be limited to 256 characters."
+     # Check allowed characters
+     allowed_characters_pattern = re.compile(r'^[A-Za-z0-9 _\-.,()\/]+$')
+     if not allowed_characters_pattern.match(name):
+         return "Dataset name can only contain letters, numbers, spaces, and the following punctuation symbols: _-.,()/. Other characters are not supported."
+     return None
+
+ def create_new_dataset_labelbox(new_dataset_name):
+     # Relies on the module-level labelbox_api_key collected from the UI below.
+     client = labelbox.Client(api_key=labelbox_api_key)
+     dataset = client.create_dataset(name=new_dataset_name)
+     dataset_id = dataset.uid
+     return dataset_id
+
+
+ def get_dataset_from_labelbox(labelbox_api_key):
+     client = labelbox.Client(api_key=labelbox_api_key)
+     datasets = client.get_datasets()
+     return datasets
+
+ def destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key):
+     DOMAIN = f"https://{domain}"
+     TOKEN = f"Bearer {databricks_api_key}"
+
+     headers = {
+         "Authorization": TOKEN,
+         "Content-Type": "application/json",
+     }
+
+     # Destroy context
+     destroy_payload = {
+         "clusterId": cluster_id,
+         "contextId": context_id
+     }
+     destroy_response = requests.post(
+         f"{DOMAIN}/api/1.2/contexts/destroy",
+         headers=headers,
+         data=json.dumps(destroy_payload)
+     )
+
+     if destroy_response.status_code != 200:
+         raise ValueError("Failed to destroy context.")
+
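+ # Execution contexts on the 1.2 Command Execution API are per-cluster resources:
+ # the helper below creates one, runs a single SQL command, and tears it down again.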
+ def execute_databricks_query(query, cluster_id, domain, databricks_api_key):
+     DOMAIN = f"https://{domain}"
+     TOKEN = f"Bearer {databricks_api_key}"
+
+     headers = {
+         "Authorization": TOKEN,
+         "Content-Type": "application/json",
+     }
+
+     # Create context
+     context_payload = {
+         "clusterId": cluster_id,
+         "language": "sql"
+     }
+     context_response = requests.post(
+         f"{DOMAIN}/api/1.2/contexts/create",
+         headers=headers,
+         data=json.dumps(context_payload)
+     )
+     context_response_data = context_response.json()
+
+     if 'id' not in context_response_data:
+         raise ValueError("Failed to create context.")
+     context_id = context_response_data['id']
+
+     # Execute query
+     command_payload = {
+         "clusterId": cluster_id,
+         "contextId": context_id,
+         "language": "sql",
+         "command": query
+     }
+     command_response = requests.post(
+         f"{DOMAIN}/api/1.2/commands/execute",
+         headers=headers,
+         data=json.dumps(command_payload)
+     ).json()
+
+     if 'id' not in command_response:
+         raise ValueError("Failed to execute command.")
+     command_id = command_response['id']
+
+     # Wait for the command to complete
+     while True:
+         status_response = requests.get(
+             f"{DOMAIN}/api/1.2/commands/status",
+             headers=headers,
+             params={
+                 "clusterId": cluster_id,
+                 "contextId": context_id,
+                 "commandId": command_id
+             }
+         ).json()
+
+         command_status = status_response.get("status")
+
+         if command_status == "Finished":
+             break
+         elif command_status in ["Error", "Cancelled"]:
+             raise ValueError(f"Command {command_status}. Reason: {status_response.get('results', {}).get('summary')}")
+         else:
+             time.sleep(1)  # Wait 1 second before checking again
+
+     # Convert the results into a pandas DataFrame
+     data = status_response.get('results', {}).get('data', [])
+     columns = [col['name'] for col in status_response.get('results', {}).get('schema', [])]
+     df = pd.DataFrame(data, columns=columns)
+
+     destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key)
+
+     return df
+
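+ # Example usage (hypothetical cluster ID and workspace domain):
+ #   df = execute_databricks_query("SHOW DATABASES;", "0123-456789-abcde",
+ #                                 "myworkspace.cloud.databricks.com", databricks_api_key)
+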
+
+
+ st.title("Labelbox 🤝 Databricks")
+ st.header("Pipeline Creator", divider='rainbow')
+
+
+ def is_valid_url_or_uri(value):
+     """Check if the provided value is a valid URL or URI."""
+     # Check general URLs
+     url_pattern = re.compile(
+         r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
+     )
+
+     # Check general URIs including cloud storage URIs (like gs://, s3://, etc.)
+     uri_pattern = re.compile(
+         r'^(?:[a-z][a-z0-9+.-]*:|/)(?:/?[^\s]*)?$|^(gs|s3|azure|blob)://[^\s]+'
+     )
+
+     return url_pattern.match(value) or uri_pattern.match(value)
+
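+ # For example, "https://storage.example.com/image.png" and "gs://my-bucket/image.png"
+ # both pass this check, while a bare file name such as "image.png" does not.
+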
+
+
+ is_preview = st.toggle('Run in Preview Mode', value=False)
+ if is_preview:
+     st.success('Running in Preview mode!', icon="✅")
+ else:
+     st.success('Running in Production mode!', icon="✅")
+
+ st.subheader("Tell us about your Databricks and Labelbox environments", divider='grey')
+ title = st.text_input('Enter Databricks Domain (e.g., <instance>.<cloud>.databricks.com)', '')
+ databricks_api_key = st.text_input('Databricks API Key', type='password')
+ labelbox_api_key = st.text_input('Labelbox API Key', type='password')
+
+ # After the Labelbox API key is entered
+ if labelbox_api_key:
+     # Fetch datasets
+     datasets = get_dataset_from_labelbox(labelbox_api_key)
+     create_new_dataset = st.toggle("Make me a new dataset", value=False)
+
+     if not create_new_dataset:
+         # Existing logic for selecting datasets
+         dataset_name_to_id = {dataset.name: dataset.uid for dataset in datasets}
+         selected_dataset_name = st.selectbox("Select an existing dataset:", list(dataset_name_to_id.keys()))
+         dataset_id = dataset_name_to_id[selected_dataset_name]
+
+     else:
+         # If the user toggles "make me a new dataset"
+         new_dataset_name = st.text_input("Enter the new dataset name:")
+
+         # Check if the name is valid
+         if new_dataset_name:
+             validation_message = validate_dataset_name(new_dataset_name)
+             if validation_message:
+                 st.error(validation_message, icon="🚫")
+             else:
+                 st.success("Valid dataset name!", icon="✅")
+                 dataset_name = new_dataset_name
+
+     # Define the variables beforehand with default values (if not defined)
+     new_dataset_name = new_dataset_name if 'new_dataset_name' in locals() else None
+     selected_dataset_name = selected_dataset_name if 'selected_dataset_name' in locals() else None
+
+     if new_dataset_name or selected_dataset_name:
+         # Handle various formats of input
+         formatted_title = re.sub(r'^https?://', '', title)  # Remove http:// or https://
+         formatted_title = re.sub(r'/$', '', formatted_title)  # Remove trailing slash if present
+
+         if formatted_title:
+             st.subheader("Select an existing cluster or make a new one", divider='grey', help="Jobs in preview mode will use all-purpose compute clusters to help you iterate faster. Jobs in production mode will use job clusters to reduce DBUs consumed.")
+             DOMAIN = f"https://{formatted_title}"
+             TOKEN = f"Bearer {databricks_api_key}"
+
+             HEADERS = {
+                 "Authorization": TOKEN,
+                 "Content-Type": "application/json",
+             }
+
+             # Endpoint to list clusters
+             ENDPOINT = "/api/2.0/clusters/list"
+
+             try:
+                 response = requests.get(DOMAIN + ENDPOINT, headers=HEADERS)
+                 response.raise_for_status()
+
+                 # Include clusters with cluster_source "UI" or "API"
+                 clusters = response.json().get("clusters", [])
+                 cluster_dict = {
+                     cluster["cluster_name"]: cluster["cluster_id"]
+                     for cluster in clusters if cluster.get("cluster_source") in ["UI", "API"]
+                 }
+
+                 # Display dropdown with cluster names
+                 make_cluster = st.toggle('Make me a new cluster', value=False)
+                 if make_cluster:
+                     # Make a cluster
+                     st.write("Making a new cluster")
+                 else:
+                     if cluster_dict:
+                         selected_cluster_name = st.selectbox(
+                             'Select a cluster to run on',
+                             list(cluster_dict.keys()),
+                             key='unique_key_for_cluster_selectbox',
+                             index=None,
+                             placeholder="Select a cluster..",
+                         )
+                         if selected_cluster_name:
+                             cluster_id = cluster_dict[selected_cluster_name]
+                     else:
+                         st.write("No UI or API-based compute clusters found.")
+
+             except requests.RequestException as e:
+                 st.write(f"Error communicating with Databricks API: {str(e)}")
+             except ValueError:
+                 st.write("Received unexpected response from Databricks API.")
+
+             if selected_cluster_name and cluster_id:
+                 # Check if the selected cluster is running
+                 cluster_state = [cluster["state"] for cluster in clusters if cluster["cluster_id"] == cluster_id][0]
+
+                 # If the cluster is not running, start it
+                 if cluster_state != "RUNNING":
+                     with st.spinner("Starting the selected cluster. This typically takes 10 minutes. Please wait..."):
+                         start_response = requests.post(f"{DOMAIN}/api/2.0/clusters/start", headers=HEADERS, json={"cluster_id": cluster_id})
+                         start_response.raise_for_status()
+
+                         # Poll until the cluster is up or until timeout
+                         start_time = time.time()
+                         timeout = 1200  # 20 minutes in seconds
+                         while True:
+                             cluster_response = requests.get(f"{DOMAIN}/api/2.0/clusters/get", headers=HEADERS, params={"cluster_id": cluster_id}).json()
+                             if "state" in cluster_response:
+                                 if cluster_response["state"] == "RUNNING":
+                                     break
+                                 elif cluster_response["state"] in ["TERMINATED", "ERROR"]:
+                                     st.write(f"Error starting cluster. Current state: {cluster_response['state']}")
+                                     break
+
+                             if (time.time() - start_time) > timeout:
+                                 st.write("Timeout reached while starting the cluster.")
+                                 break
+
+                             time.sleep(10)  # Check every 10 seconds
+
+                     st.success(f"{selected_cluster_name} is now running!", icon="🏃‍♂️")
+                 else:
+                     st.success(f"{selected_cluster_name} is already running!", icon="🏃‍♂️")
+
+
+                 def generate_cron_expression(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
+                     """
+                     Generate a cron expression based on the provided frequency and time.
+                     """
+                     if freq == "1 minute":
+                         return "0 * * * * ?"
+                     elif freq == "1 hour":
+                         return f"0 {minute} * * * ?"
+                     elif freq == "1 day":
+                         return f"0 {minute} {hour} * * ?"
+                     elif freq == "1 week":
+                         if not day_of_week:
+                             raise ValueError("Day of week not provided for weekly frequency.")
+                         return f"0 {minute} {hour} ? * {day_of_week}"
+                     elif freq == "1 month":
+                         if not day_of_month:
+                             raise ValueError("Day of month not provided for monthly frequency.")
+                         return f"0 {minute} {hour} {day_of_month} * ?"
+                     else:
+                         raise ValueError("Invalid frequency provided")
+
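+                 # These are Quartz-style cron expressions (seconds field first), e.g.
+                 # generate_cron_expression("1 week", 9, 30, day_of_week="MON") -> "0 30 9 ? * MON"
+                 # generate_cron_expression("1 day", 14, 5) -> "0 5 14 * * ?"
+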
+                 # Streamlit UI
+                 st.subheader("Run Frequency", divider='grey')
+
+                 # Dropdown to select frequency
+                 freq_options = ["1 minute", "1 hour", "1 day", "1 week", "1 month"]
+                 selected_freq = st.selectbox("Select frequency:", freq_options, placeholder="Select frequency..")
+
+                 day_of_week = None
+                 day_of_month = None
+
+                 # If the frequency is hourly, daily, weekly, or monthly, ask for a specific time
+                 if selected_freq != "1 minute":
+                     col1, col2 = st.columns(2)
+                     with col1:
+                         hour = st.selectbox("Hour:", list(range(0, 24)))
+                     with col2:
+                         minute = st.selectbox("Minute:", list(range(0, 60)))
+
+                     if selected_freq == "1 week":
+                         days_options = ["MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN"]
+                         day_of_week = st.selectbox("Select day of the week:", days_options)
+
+                     elif selected_freq == "1 month":
+                         day_of_month = st.selectbox("Select day of the month:", list(range(1, 32)))
+
+                 else:
+                     hour, minute = 0, 0
+
+                 # Generate the cron expression
+                 frequency = generate_cron_expression(selected_freq, hour, minute, day_of_week, day_of_month)
+
+                 def generate_human_readable_message(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
+                     """
+                     Generate a human-readable message for the scheduling.
+                     """
+                     if freq == "1 minute":
+                         return "Job will run every minute."
+                     elif freq == "1 hour":
+                         return f"Job will run once an hour at minute {minute}."
+                     elif freq == "1 day":
+                         return f"Job will run daily at {hour:02}:{minute:02}."
+                     elif freq == "1 week":
+                         if not day_of_week:
+                             raise ValueError("Day of week not provided for weekly frequency.")
+                         return f"Job will run every {day_of_week} at {hour:02}:{minute:02}."
+                     elif freq == "1 month":
+                         if not day_of_month:
+                             raise ValueError("Day of month not provided for monthly frequency.")
+                         return f"Job will run once a month on day {day_of_month} at {hour:02}:{minute:02}."
+                     else:
+                         raise ValueError("Invalid frequency provided")
+
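+                 # e.g. generate_human_readable_message("1 week", 9, 30, day_of_week="MON")
+                 # returns "Job will run every MON at 09:30."
+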
+                 # Generate the human-readable message
+                 readable_msg = generate_human_readable_message(selected_freq, hour, minute, day_of_week, day_of_month)
+
+                 if frequency:
+                     st.success(readable_msg, icon="📅")
+
+                 st.subheader("Select a table", divider="grey")
+
+                 with st.spinner('Querying Databricks...'):
+                     query = "SHOW DATABASES;"
+                     result_data = execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)
+
+                 # Extract the databaseName values from the DataFrame
+                 database_names = result_data['databaseName'].tolist()
+
+                 # Create a dropdown with the database names
+                 selected_database = st.selectbox("Select a Database:", database_names, index=None, placeholder="Select a database..")
+
+                 if selected_database:
+                     with st.spinner('Querying Databricks...'):
+                         query = f"SHOW TABLES IN {selected_database};"
+                         result_data = execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)
+
+                     # Extract the tableName values from the DataFrame
+                     table_names = result_data['tableName'].tolist()
+
+                     # Create a dropdown with the table names
+                     selected_table = st.selectbox("Select a Table:", table_names, index=None, placeholder="Select a table..")
+
+                     if selected_table:
+                         with st.spinner('Querying Databricks...'):
+                             query = f"SHOW COLUMNS IN {selected_database}.{selected_table};"
+                             result_data = execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)
+                             column_names = result_data['col_name'].tolist()
+
+                         st.subheader("Map table schema to Labelbox schema", divider="grey")
+
+                         # Fetch the first 5 rows of the selected table
+                         with st.spinner('Fetching first 5 rows of the selected table...'):
+                             query = f"SELECT * FROM {selected_database}.{selected_table} LIMIT 5;"
+                             table_sample_data = execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)
+
+                         # Display the sample data in the Streamlit UI
+                         st.write(table_sample_data)
+
+                         # Define two columns for side-by-side selectboxes
+                         col1, col2 = st.columns(2)
+
+                         with col1:
+                             selected_row_data = st.selectbox(
+                                 "row_data (required):",
+                                 column_names,
+                                 index=None,
+                                 placeholder="Select a column..",
+                                 help="Select the column that contains the URL/URI bucket location of the data rows you wish to import into Labelbox."
+                             )
+
+                         with col2:
+                             selected_global_key = st.selectbox(
+                                 "global_key (optional):",
+                                 column_names,
+                                 index=None,
+                                 placeholder="Select a column..",
+                                 help="Select the column that contains the global key. If not provided, a new key will be generated for you."
+                             )
+
+                         # Fetch a single row from the selected table
+                         query_sample_row = f"SELECT * FROM {selected_database}.{selected_table} LIMIT 1;"
+                         result_sample = execute_databricks_query(query_sample_row, cluster_id, formatted_title, databricks_api_key)
+
+                         if selected_row_data:
+                             # Extract the value from the selected row_data column
+                             sample_row_data_value = result_sample[selected_row_data].iloc[0]
+
+                             # Validate the extracted value
+                             if is_valid_url_or_uri(sample_row_data_value):
+                                 st.success(f"Sample URI/URL from selected row data column: {sample_row_data_value}", icon="✅")
+                                 dataset_id = create_new_dataset_labelbox(new_dataset_name) if create_new_dataset else dataset_id
+                                 # Mode
+                                 mode = "preview" if is_preview else "production"
+
+                                 # Databricks instance
+                                 databricks_instance = formatted_title
+
+                                 # New Dataset flag
+                                 new_dataset = 1 if create_new_dataset else 0
+
+                                 # Table Path
+                                 table_path = f"{selected_database}.{selected_table}"
+
+                                 # Cluster ID and New Cluster flag
+                                 new_cluster = 1 if make_cluster else 0
+                                 cluster_id = cluster_id if not make_cluster else ""
+
+                                 # Schema Map
+                                 row_data_input = selected_row_data
+                                 global_key_input = selected_global_key
+                                 schema_map_dict = {'row_data': row_data_input}
+                                 if global_key_input:
+                                     schema_map_dict['global_key'] = global_key_input
+
+                                 # Convert the dict to a stringified JSON
+                                 schema_map_str = json.dumps(schema_map_dict)
+
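+                                 # With hypothetical columns "image_url" and "asset_id" selected above,
+                                 # schema_map_str becomes '{"row_data": "image_url", "global_key": "asset_id"}'.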
+
+                                 data = {
+                                     "mode": mode,
+                                     "databricks_instance": databricks_instance,
+                                     "databricks_api_key": databricks_api_key,
+                                     "new_dataset": new_dataset,
+                                     "dataset_id": dataset_id,
+                                     "table_path": table_path,
+                                     "labelbox_api_key": labelbox_api_key,
+                                     "frequency": frequency,
+                                     "new_cluster": new_cluster,
+                                     "cluster_id": cluster_id,
+                                     "schema_map": schema_map_str
+                                 }
+
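+                                 # The payload above is POSTed to the deployment endpoint below (a Cloud
+                                 # Function), which is expected to create the scheduled Databricks job.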
+
+                                 if st.button("Deploy Pipeline!", type="primary"):
+                                     # Ensure all fields are filled out (the 0/1 flags and the
+                                     # intentionally empty cluster_id for a new cluster are exempt)
+                                     required_fields = [
+                                         mode, databricks_instance, databricks_api_key, dataset_id,
+                                         table_path, labelbox_api_key, frequency, schema_map_str
+                                     ]
+                                     if not all(required_fields):
+                                         st.error("Please fill out all required fields before deploying.", icon="🚫")
+                                         st.stop()
+
+                                     # Send a POST request to the Flask app endpoint
+                                     with st.spinner("Deploying pipeline..."):
+                                         response = requests.post("https://us-central1-dbt-prod.cloudfunctions.net/deploy-databricks-pipeline", json=data)
+
+                                     # Check if the request was successful
+                                     if response.status_code == 200:
+                                         # Display the response using Streamlit
+                                         st.balloons()
+                                         st.success("Pipeline deployed successfully!", icon="🚀")
+                                         st.json(response.json())
+                                     else:
+                                         st.error(f"Failed to deploy pipeline. Response: {response.text}", icon="🚫")
+
+                             else:
+                                 st.error(f"row_data '{sample_row_data_value}' is not a valid URI or URL. Please select a different column.", icon="🚫")