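"""preview_mode_server.py

Minimal Flask server exposing a single POST endpoint, /create-databricks-job,
which builds and submits a Databricks job that runs the labelspark upload
notebooks against a source table, in either preview or production mode.
"""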
import json
import threading

import requests
from flask import Flask, request, jsonify

app = Flask(__name__)
def create_databricks_job(data):
    """Build and submit a Databricks job from the request payload."""
    mode = data.get('mode')
    databricks_instance = data.get('databricks_instance')
    databricks_api_key = data.get('databricks_api_key')
    new_dataset = data.get('new_dataset')  # accepted in the payload but not used by the job definitions below
    dataset_id = data.get('dataset_id')
    table_path = data.get('table_path')
    labelbox_api_key = data.get('labelbox_api_key')
    frequency = data.get('frequency')
    new_cluster = data.get('new_cluster')
    cluster_id = data.get('cluster_id')
    schema_map = data.get('schema_map')

    # Define the authentication headers for the Databricks REST API
    headers = {
        "Authorization": f"Bearer {databricks_api_key}",
        "Content-Type": "application/json",
    }
    # ----- CLUSTER CREATION LOGIC -----
    def create_all_purpose_cluster(databricks_instance):
        """Create a new all-purpose cluster and return its cluster ID."""
        url = f"https://{databricks_instance}/api/2.0/clusters/create"
        cluster_payload = {
            "autoscale": {
                "min_workers": 1,
                "max_workers": 10
            },
            "cluster_name": "Labelbox Worker",
            "spark_version": "11.3.x-scala2.12",
            "gcp_attributes": {
                "use_preemptible_executors": False,
                "availability": "PREEMPTIBLE_WITH_FALLBACK_GCP",
                "zone_id": "HA"
            },
            "node_type_id": "n2-highmem-4",
            "driver_node_type_id": "n2-highmem-4",
            "ssh_public_keys": [],
            "custom_tags": {},
            "cluster_log_conf": {
                "dbfs": {
                    "destination": "dbfs:/cluster-logs"
                }
            },
            "spark_env_vars": {},
            "autotermination_minutes": 60,
            "enable_elastic_disk": False,
            "init_scripts": [],
            "enable_local_disk_encryption": False,
            "runtime_engine": "STANDARD"
        }
        response = requests.post(url, data=json.dumps(cluster_payload), headers=headers)
        if response.status_code == 200:
            return response.json()['cluster_id']
        else:
            raise Exception(f"Failed to create all-purpose cluster. Error: {response.text}")
    # ----- PREVIEW MODE LOGIC -----
    def create_preview(dataset_id, table_path, labelbox_api_key, frequency, cluster_id):
        """Build the job payload for a preview upload run on an existing all-purpose cluster."""
        # ----- JOB SCHEDULING LOGIC -----
        # Run continuously, or on the supplied Quartz cron expression
        if frequency == "continuous":
            schedule_block = {
                "continuous": {
                    "pause_status": "UNPAUSED"
                }
            }
        else:
            schedule_block = {
                "schedule": {
                    "quartz_cron_expression": frequency,
                    "timezone_id": "UTC",
                    "pause_status": "UNPAUSED"
                }
            }

        # ----- JOB DEFINITION -----
        # Define the parameters and structure of the job to be created in Databricks
        payload = {
            "name": "PREVIEW_upload_to_labelbox",
            "email_notifications": {"no_alert_for_skipped_runs": False},
            "webhook_notifications": {},
            "timeout_seconds": 0,
            "max_concurrent_runs": 1,
            "tasks": [
                {
                    "existing_cluster_id": cluster_id,  # Preview runs reuse an all-purpose cluster
                    "task_key": "PREVIEW_upload_to_labelbox",
                    "run_if": "ALL_SUCCESS",
                    "notebook_task": {
                        "notebook_path": "notebooks/databricks_pipeline_creator/preview_upload_to_labelbox",
                        "base_parameters": {
                            "dataset_id": dataset_id,
                            "table_path": table_path,
                            "labelbox_api_key": labelbox_api_key,
                            "schema_map": schema_map
                        },
                        "source": "GIT"
                    },
                    "libraries": [
                        {"pypi": {"package": "labelspark"}},
                        {"pypi": {"package": "labelbox==3.49.1"}},
                        {"pypi": {"package": "numpy==1.25"}},
                        {"pypi": {"package": "opencv-python==4.8.0.74"}}
                    ],
                    "timeout_seconds": 0,
                    "email_notifications": {},
                    "notification_settings": {
                        "no_alert_for_skipped_runs": False,
                        "no_alert_for_canceled_runs": False,
                        "alert_on_last_attempt": False
                    }
                }
            ],
            "git_source": {
                "git_url": "https://github.com/Labelbox/labelspark.git",
                "git_provider": "gitHub",
                "git_branch": "master"
            },
            "format": "MULTI_TASK"
        }

        # Merge the scheduling configuration into the main job payload
        payload.update(schedule_block)
        return payload
    # ----- PRODUCTION MODE LOGIC -----
    def create_production(dataset_id, table_path, labelbox_api_key, frequency):
        """Build the job payload for a production upload run on a dedicated job cluster."""
        # ----- JOB SCHEDULING LOGIC -----
        # If the job needs to run continuously, use the "continuous" block;
        # otherwise, use the "schedule" block with the specified cron frequency
        if frequency == "continuous":
            schedule_block = {
                "continuous": {
                    "pause_status": "UNPAUSED"
                }
            }
        else:
            schedule_block = {
                "schedule": {
                    "quartz_cron_expression": frequency,
                    "timezone_id": "UTC",
                    "pause_status": "UNPAUSED"
                }
            }

        # ----- JOB DEFINITION -----
        # Define the parameters and structure of the job to be created in Databricks
        payload = {
            "name": "upload_to_labelbox",
            "email_notifications": {"no_alert_for_skipped_runs": False},
            "webhook_notifications": {},
            "timeout_seconds": 0,
            "max_concurrent_runs": 1,
            "tasks": [
                {
                    "task_key": "upload_to_labelbox",
                    "run_if": "ALL_SUCCESS",
                    "notebook_task": {
                        "notebook_path": "notebooks/databricks_pipeline_creator/upload_to_labelbox",
                        "base_parameters": {
                            "dataset_id": dataset_id,
                            "table_path": table_path,
                            "labelbox_api_key": labelbox_api_key,
                            "schema_map": schema_map
                        },
                        "source": "GIT"
                    },
                    "job_cluster_key": "Job_cluster",
                    "libraries": [
                        {"pypi": {"package": "labelspark"}},
                        {"pypi": {"package": "labelbox==3.49.1"}},
                        {"pypi": {"package": "numpy==1.25"}},
                        {"pypi": {"package": "opencv-python==4.8.0.74"}}
                    ],
                    "timeout_seconds": 0,
                    "email_notifications": {},
                    "notification_settings": {
                        "no_alert_for_skipped_runs": False,
                        "no_alert_for_canceled_runs": False,
                        "alert_on_last_attempt": False
                    }
                }
            ],
            "job_clusters": [
                {
                    "job_cluster_key": "Job_cluster",
                    "new_cluster": {
                        "cluster_name": "",
                        "spark_version": "13.3.x-scala2.12",
                        "gcp_attributes": {
                            "use_preemptible_executors": False,
                            "availability": "ON_DEMAND_GCP",
                            "zone_id": "HA"
                        },
                        "node_type_id": "n2-highmem-4",
                        "enable_elastic_disk": True,
                        "data_security_mode": "SINGLE_USER",
                        "runtime_engine": "STANDARD",
                        "autoscale": {
                            "min_workers": 1,
                            "max_workers": 10
                        }
                    }
                }
            ],
            "git_source": {
                "git_url": "https://github.com/Labelbox/labelspark.git",
                "git_provider": "gitHub",
                "git_branch": "master"
            },
            "format": "MULTI_TASK"
        }

        # Merge the scheduling configuration into the main job payload
        payload.update(schedule_block)
        return payload
    # ----- CLUSTER SELECTION -----
    # Create a new all-purpose cluster when requested; otherwise reuse the supplied cluster_id
    if new_cluster:
        cluster_id = create_all_purpose_cluster(databricks_instance)
        print(f"Created all-purpose cluster with ID: {cluster_id}")
    else:
        print(f"Using existing cluster with ID: {cluster_id}")

    if mode == "preview":
        payload = create_preview(dataset_id, table_path, labelbox_api_key, frequency, cluster_id)
    elif mode == "production":
        payload = create_production(dataset_id, table_path, labelbox_api_key, frequency)
    else:
        return f"Invalid mode: {mode}"

    # ----- JOB CREATION -----
    # Formulate the endpoint URL for the Databricks REST API job creation
    url = f"https://{databricks_instance}/api/2.0/jobs/create"
    # Send the POST request to Databricks to create the job
    response = requests.post(url, data=json.dumps(payload), headers=headers)

    # ----- RESPONSE HANDLING -----
    if response.status_code == 200:
        return f"Job created successfully. {response.text}"
    else:
        return f"Failed to create job. Error: {response.text}"
@app.route('/create-databricks-job', methods=['POST'])
def api_create_databricks_job():
    data = request.get_json()
    result = create_databricks_job(data)
    return jsonify({"message": result})


def run():
    app.run(port=5000)


# Start the Flask server on a background thread so the hosting process is not blocked
threading.Thread(target=run).start()
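
# Example request (a sketch; all values below are illustrative placeholders and must
# be replaced with real Databricks / Labelbox identifiers before use):
#
#   import requests
#   requests.post(
#       "http://localhost:5000/create-databricks-job",
#       json={
#           "mode": "preview",                     # or "production"
#           "databricks_instance": "<workspace-host>",
#           "databricks_api_key": "<databricks-token>",
#           "dataset_id": "<labelbox-dataset-id>",
#           "table_path": "<path.to.table>",
#           "labelbox_api_key": "<labelbox-api-key>",
#           "frequency": "continuous",             # or a Quartz cron expression
#           "new_cluster": True,                   # create a new all-purpose cluster
#           "cluster_id": None,                    # or an existing cluster ID
#           "schema_map": "<schema-map>",          # passed through to the upload notebook
#       },
#   )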