import streamlit as st
import requests
import re
import json
import time
import pandas as pd
import labelbox

@st.cache_data(show_spinner=True)
def fetch_databases(cluster_id, formatted_title, databricks_api_key):
    """Run ``SHOW DATABASES`` on the given cluster and return the query result.

    The result is whatever ``execute_databricks_query`` returns; Streamlit
    caches it per argument combination via ``st.cache_data``.
    """
    return execute_databricks_query(
        "SHOW DATABASES;", cluster_id, formatted_title, databricks_api_key
    )

# Cached function to fetch tables
@st.cache_data(show_spinner=True)
def fetch_tables(selected_database, cluster_id, formatted_title, databricks_api_key):
    """Run ``SHOW TABLES IN <selected_database>`` and return the query result.

    The result is whatever ``execute_databricks_query`` returns; Streamlit
    caches it per argument combination via ``st.cache_data``.
    """
    return execute_databricks_query(
        f"SHOW TABLES IN {selected_database};",
        cluster_id,
        formatted_title,
        databricks_api_key,
    )

# Cached function to fetch columns
@st.cache_data(show_spinner=True)
def fetch_columns(selected_database, selected_table, cluster_id, formatted_title, databricks_api_key):
    """Run ``SHOW COLUMNS IN <database>.<table>`` and return the query result.

    The result is whatever ``execute_databricks_query`` returns; Streamlit
    caches it per argument combination via ``st.cache_data``.
    """
    return execute_databricks_query(
        f"SHOW COLUMNS IN {selected_database}.{selected_table};",
        cluster_id,
        formatted_title,
        databricks_api_key,
    )

# Compiled once at import time; used with fullmatch() so the whole name must
# consist of allowed characters.
_DATASET_NAME_ALLOWED = re.compile(r'[A-Za-z0-9 _\-.,()\/]+')
_DATASET_NAME_MAX_LENGTH = 256


def validate_dataset_name(name):
    """Validate a dataset name.

    Args:
        name: The candidate dataset name (a string).

    Returns:
        ``None`` when the name is acceptable, otherwise a human-readable
        message describing the first rule it violates: longer than 256
        characters, or containing anything other than letters, digits,
        spaces and the punctuation ``_-.,()/``. An empty name fails the
        character rule.
    """
    # Check length
    if len(name) > _DATASET_NAME_MAX_LENGTH:
        return "Dataset name should be limited to 256 characters."
    # Check allowed characters. fullmatch() is used instead of match() with a
    # trailing "$", because "$" also matches just before a final newline and
    # would let a name such as "abc\n" through.
    if not _DATASET_NAME_ALLOWED.fullmatch(name):
        return ("Dataset name can only contain letters, numbers, spaces, and the following punctuation symbols: _-.,()/. Other characters are not supported.")
    return None

def create_new_dataset_labelbox(new_dataset_name, api_key=None):
    """Create a Labelbox dataset and return its id.

    Args:
        new_dataset_name: Name to give the new dataset.
        api_key: Labelbox API key to authenticate with. When omitted, the
            module-level ``labelbox_api_key`` (set by the UI further down the
            script) is used, which keeps existing one-argument calls working.

    Returns:
        The ``uid`` of the dataset object returned by ``client.create_dataset``.
    """
    if api_key is None:
        # Backward-compatible fallback to the key collected by the page.
        api_key = labelbox_api_key
    client = labelbox.Client(api_key=api_key)
    return client.create_dataset(name=new_dataset_name).uid


def get_dataset_from_labelbox(labelbox_api_key):
    """Return the result of ``get_datasets()`` for a client built from the given API key."""
    return labelbox.Client(api_key=labelbox_api_key).get_datasets()

def destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key):
    """Destroy an execution context through the ``/api/1.2/contexts/destroy`` endpoint.

    Args:
        cluster_id: Id of the cluster that owns the context.
        context_id: Id of the context to destroy.
        domain: Workspace host name, without scheme.
        databricks_api_key: Token sent as a ``Bearer`` authorization header.

    Raises:
        ValueError: If the endpoint answers with any status other than 200.
    """
    request_headers = {
        "Authorization": f"Bearer {databricks_api_key}",
        "Content-Type": "application/json",
    }
    body = json.dumps({"clusterId": cluster_id, "contextId": context_id})

    reply = requests.post(
        f"https://{domain}/api/1.2/contexts/destroy",
        headers=request_headers,
        data=body,
    )

    if reply.status_code == 200:
        return
    raise ValueError("Failed to destroy context.")
    
def execute_databricks_query(query, cluster_id, domain, databricks_api_key):
    """Run a SQL statement on a Databricks cluster and return the result as a DataFrame.

    The function talks to the ``/api/1.2`` endpoints: it creates an execution
    context, submits the command, polls the command status once per second
    until it leaves the queue, and finally destroys the context. The context
    is destroyed on failure paths too, so failed queries do not leave
    contexts behind on the cluster.

    Args:
        query: SQL text to execute.
        cluster_id: Id of the cluster to run on.
        domain: Workspace host name, without scheme.
        databricks_api_key: Token sent as a ``Bearer`` authorization header.

    Returns:
        A ``pandas.DataFrame`` built from the ``data`` rows of the command
        result, with column names taken from the result ``schema``.

    Raises:
        ValueError: If the context or command cannot be created, if the
            command ends in ``Error``/``Cancelled``, if it finishes with an
            error result, or if destroying the context after a successful
            query fails.
    """
    base_url = f"https://{domain}"
    headers = {
        "Authorization": f"Bearer {databricks_api_key}",
        "Content-Type": "application/json",
    }

    # Create context
    context_response_data = requests.post(
        f"{base_url}/api/1.2/contexts/create",
        headers=headers,
        data=json.dumps({"clusterId": cluster_id, "language": "sql"}),
    ).json()

    if 'id' not in context_response_data:
        raise ValueError("Failed to create context.")
    context_id = context_response_data['id']

    try:
        # Execute query
        command_payload = {
            "clusterId": cluster_id,
            "contextId": context_id,
            "language": "sql",
            "command": query,
        }
        command_response = requests.post(
            f"{base_url}/api/1.2/commands/execute",
            headers=headers,
            data=json.dumps(command_payload),
        ).json()

        if 'id' not in command_response:
            raise ValueError("Failed to execute command.")
        command_id = command_response['id']

        # Wait for the command to complete
        while True:
            status_response = requests.get(
                f"{base_url}/api/1.2/commands/status",
                headers=headers,
                params={
                    "clusterId": cluster_id,
                    "contextId": context_id,
                    "commandId": command_id,
                },
            ).json()

            command_status = status_response.get("status")
            # "results" may be missing or null while the command is still running.
            results = status_response.get('results') or {}

            if command_status == "Finished":
                break
            if command_status in ("Error", "Cancelled"):
                raise ValueError(f"Command {command_status}. Reason: {results.get('summary')}")
            time.sleep(1)  # Poll once per second

        # NOTE(review): per the Command Execution API a failing statement is
        # reported as status "Finished" with resultType "error"; surface it
        # instead of returning an empty frame. Confirm against the API docs.
        if results.get("resultType") == "error":
            raise ValueError(f"Command Error. Reason: {results.get('summary')}")

        # Convert the results into a pandas DataFrame
        data = results.get('data', [])
        columns = [col['name'] for col in results.get('schema', [])]
        df = pd.DataFrame(data, columns=columns)
    except Exception:
        # Best-effort cleanup: the query already failed, so a failure to
        # destroy the context must not replace the original error.
        try:
            destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key)
        except (ValueError, requests.RequestException):
            pass
        raise

    destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key)

    return df


# Page heading: the app title followed by a section header with a rainbow divider.
st.title("Labelbox 🤝 Databricks")
st.header("Pipeline Creator", divider='rainbow')



# http(s) URLs; matched from the start of the string only (no end anchor).
_HTTP_URL_RE = re.compile(
    r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
)

# Generic scheme-prefixed or slash-prefixed URIs, plus cloud storage URIs
# (gs://, s3://, azure://, blob://).
_GENERIC_URI_RE = re.compile(
    r'^(?:[a-z][a-z0-9+.-]*:|/)(?:/?[^\s]*)?$|^(gs|s3|azure|blob)://[^\s]+'
)


def is_valid_url_or_uri(value):
    """Check whether ``value`` looks like a URL or URI.

    Returns the regex match object of the first pattern that matches (the
    http(s) URL pattern is tried before the generic URI pattern), or ``None``
    when neither matches, so the result can be used as a truth value.
    """
    url_hit = _HTTP_URL_RE.match(value)
    if url_hit:
        return url_hit
    return _GENERIC_URI_RE.match(value)



# Let the user switch between preview and production; confirm the active mode.
is_preview = st.toggle('Run in Preview Mode', value=False)
mode_banner = 'Running in Preview mode!' if is_preview else 'Running in Production mode!'
st.success(mode_banner, icon="✅")

# Collect the workspace details and credentials the rest of the page relies on.
st.subheader("Tell us about your Databricks and Labelbox environments", divider='grey')
# index=None is passed so no cloud is pre-selected; `cloud` is forwarded
# unchanged in the deployment payload later in the script.
cloud = st.selectbox('Which cloud environment does your Databricks Workspace run in?', ['AWS', 'Azure', 'GCP'], index=None)
# Raw domain text as typed; scheme and trailing slash are stripped later
# into `formatted_title`.
title = st.text_input('Enter Databricks Domain (e.g., <instance>.<cloud>.databricks.com)', '')
# Both keys are read through password-type inputs.
databricks_api_key = st.text_input('Databricks API Key', type='password')
labelbox_api_key = st.text_input('Labelbox API Key', type='password')

# After Labelbox API key is entered
if labelbox_api_key:
    # Fetching datasets
    datasets = get_dataset_from_labelbox(labelbox_api_key)
    create_new_dataset = st.toggle("Make me a new dataset", value=False)

    if not create_new_dataset:
        # Map dataset names to ids for the dropdown. Datasets sharing a name
        # collapse to a single entry (the last one seen wins).
        dataset_name_to_id = {dataset.name: dataset.uid for dataset in datasets}
        selected_dataset_name = st.selectbox("Select an existing dataset:", list(dataset_name_to_id))
        # With no datasets available the selectbox yields None; .get() avoids
        # the KeyError a direct lookup would raise in that case.
        dataset_id = dataset_name_to_id.get(selected_dataset_name)

    else:
        # If user toggles "make me a new dataset"
        new_dataset_name = st.text_input("Enter the new dataset name:")

        # Check if the name is valid
        if new_dataset_name:
            validation_message = validate_dataset_name(new_dataset_name)
            if validation_message:
                st.error(validation_message, icon="🚫")
                # Discard the rejected name so the rest of the page does not
                # continue (and later create a dataset) with an invalid name.
                new_dataset_name = None
            else:
                st.success("Valid dataset name!", icon="✅")
                dataset_name = new_dataset_name

# Make sure both names exist for the check below: keep the value when the
# dataset section above defined it, otherwise fall back to None.
new_dataset_name = locals().get('new_dataset_name')
selected_dataset_name = locals().get('selected_dataset_name')

# Endpoint that receives the deployment payload built at the end of the flow.
DEPLOY_PIPELINE_URL = "https://us-central1-dbt-prod.cloudfunctions.net/deploy-databricks-pipeline"


def generate_cron_expression(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
    """Generate a cron expression for the given frequency and time.

    The expression has six fields with seconds first and uses ``?`` for the
    unused day field (Quartz-style syntax; presumably what the deployed
    Databricks job schedule expects - confirm against the deploy endpoint).

    Args:
        freq: One of "1 minute", "1 hour", "1 day", "1 week", "1 month".
        hour: Hour of day, used for daily/weekly/monthly schedules.
        minute: Minute of the hour, used for every schedule except "1 minute".
        day_of_week: Day abbreviation (e.g. "MON"); required for "1 week".
        day_of_month: Day number; required for "1 month".

    Raises:
        ValueError: On an unknown frequency or a missing required day.
    """
    if freq == "1 minute":
        return "0 * * * * ?"
    if freq == "1 hour":
        return f"0 {minute} * * * ?"
    if freq == "1 day":
        return f"0 {minute} {hour} * * ?"
    if freq == "1 week":
        if not day_of_week:
            raise ValueError("Day of week not provided for weekly frequency.")
        return f"0 {minute} {hour} ? * {day_of_week}"
    if freq == "1 month":
        if not day_of_month:
            raise ValueError("Day of month not provided for monthly frequency.")
        return f"0 {minute} {hour} {day_of_month} * ?"
    raise ValueError("Invalid frequency provided")


def generate_human_readable_message(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
    """Generate a human-readable description of the schedule.

    Takes the same arguments as ``generate_cron_expression`` and raises
    ``ValueError`` under the same conditions.
    """
    if freq == "1 minute":
        return "Job will run every minute."
    if freq == "1 hour":
        return f"Job will run once an hour at minute {minute}."
    if freq == "1 day":
        return f"Job will run daily at {hour:02}:{minute:02}."
    if freq == "1 week":
        if not day_of_week:
            raise ValueError("Day of week not provided for weekly frequency.")
        return f"Job will run every {day_of_week} at {hour:02}:{minute:02}."
    if freq == "1 month":
        if not day_of_month:
            raise ValueError("Day of month not provided for monthly frequency.")
        return f"Job will run once a month on day {day_of_month} at {hour:02}:{minute:02}."
    raise ValueError("Invalid frequency provided")


def _start_cluster_and_wait(domain_url, headers, cluster_id, timeout=1200, poll_interval=10):
    """Start a cluster and poll until it is running.

    Args:
        domain_url: Workspace base URL including scheme.
        headers: Request headers carrying the authorization token.
        cluster_id: Id of the cluster to start.
        timeout: Maximum number of seconds to wait (default 20 minutes).
        poll_interval: Seconds to sleep between state checks.

    Returns:
        ``(True, None)`` once the cluster reports ``RUNNING``; otherwise
        ``(False, message)`` where ``message`` explains what went wrong
        (terminal state, timeout, or a request/response problem).
    """
    try:
        start_response = requests.post(
            f"{domain_url}/api/2.0/clusters/start",
            headers=headers,
            json={"cluster_id": cluster_id},
        )
        start_response.raise_for_status()

        deadline = time.time() + timeout
        while True:
            cluster_response = requests.get(
                f"{domain_url}/api/2.0/clusters/get",
                headers=headers,
                params={"cluster_id": cluster_id},
            ).json()
            state = cluster_response.get("state")
            if state == "RUNNING":
                return True, None
            if state in ("TERMINATED", "ERROR"):
                return False, f"Error starting cluster. Current state: {state}"
            if time.time() > deadline:
                return False, "Timeout reached while starting the cluster."
            time.sleep(poll_interval)
    except requests.RequestException as e:
        return False, f"Error communicating with Databricks API: {str(e)}"
    except ValueError:
        return False, "Received unexpected response from Databricks API."


def _extract_job_id(response):
    """Return the numeric job id embedded in the deploy response message, or None.

    The response body is expected to be JSON with a ``message`` string that
    contains ``job_id": <digits>``; anything else yields ``None``.
    """
    try:
        message = response.json()["message"]
    except (ValueError, KeyError, TypeError):
        return None
    found = re.search(r'job_id"\s*:\s*"?(\d+)', str(message))
    return found.group(1) if found else None


def _build_job_url(databricks_instance, job_id):
    """Build the workspace link to a job from the scheme-less instance string."""
    from urllib.parse import parse_qs, urlparse

    # Re-attach a scheme so urlparse separates host and query correctly.
    parsed = urlparse(f"https://{databricks_instance}")
    organization_id = parse_qs(parsed.query).get("o", [""])[0]
    # Link over https, the same scheme used for every API call to this host.
    return f"https://{parsed.netloc}/?o={organization_id}#job/{job_id}"


if new_dataset_name or selected_dataset_name:
    # Handling various formats of input
    formatted_title = re.sub(r'^https?://', '', title)  # Remove http:// or https://
    formatted_title = re.sub(r'/$', '', formatted_title)  # Remove trailing slash if present

    if formatted_title:
        st.subheader("Select an existing cluster", divider='grey', help="Jobs will use job clusters to reduce DBUs consumed.")
        DOMAIN = f"https://{formatted_title}"
        TOKEN = f"Bearer {databricks_api_key}"

        HEADERS = {
            "Authorization": TOKEN,
            "Content-Type": "application/json",
        }

        # Endpoint to list clusters
        ENDPOINT = "/api/2.0/clusters/list"

        # Defaults so a failed request, an empty cluster list or an untouched
        # dropdown cannot leave these names undefined further down.
        clusters = []
        selected_cluster_name = None
        cluster_id = None

        try:
            response = requests.get(DOMAIN + ENDPOINT, headers=HEADERS)
            response.raise_for_status()

            # Include clusters with cluster_source "UI" or "API"
            clusters = response.json().get("clusters", [])
            cluster_dict = {
                cluster["cluster_name"]: cluster["cluster_id"]
                for cluster in clusters if cluster.get("cluster_source") in ["UI", "API"]
            }

            # Display dropdown with cluster names
            if cluster_dict:
                selected_cluster_name = st.selectbox(
                    'Select a cluster to run on',
                    list(cluster_dict.keys()),
                    key='unique_key_for_cluster_selectbox',
                    index=None,
                    placeholder="Select a cluster..",
                )
                if selected_cluster_name:
                    cluster_id = cluster_dict[selected_cluster_name]

        except requests.RequestException as e:
            st.write(f"Error communicating with Databricks API: {str(e)}")
        except ValueError:
            st.write("Received unexpected response from Databricks API.")

        # Only continue with the rest of the page once the cluster is
        # confirmed to be running.
        cluster_ready = False
        if selected_cluster_name and cluster_id:
            # Check if the selected cluster is running
            cluster_state = next(
                (cluster.get("state") for cluster in clusters if cluster.get("cluster_id") == cluster_id),
                None,
            )

            if cluster_state == "RUNNING":
                cluster_ready = True
                st.success(f"{selected_cluster_name} is already running!", icon="🏃‍♂️")
            else:
                # If the cluster is not running, start it
                with st.spinner("Starting the selected cluster. This typically takes 10 minutes. Please wait..."):
                    cluster_ready, start_problem = _start_cluster_and_wait(DOMAIN, HEADERS, cluster_id)
                if cluster_ready:
                    st.success(f"{selected_cluster_name} is now running!", icon="🏃‍♂️")
                else:
                    st.write(start_problem)

        if cluster_ready:
            # Streamlit UI
            st.subheader("Run Frequency", divider='grey')

            # Dropdown to select frequency
            freq_options = ["1 day", "1 week", "1 month"]
            selected_freq = st.selectbox("Select frequency:", freq_options, placeholder="Select frequency..")

            day_of_week = None
            day_of_month = None

            # Every frequency except "1 minute" needs a time of day; weekly and
            # monthly schedules additionally need a day.
            if selected_freq != "1 minute":

                if selected_freq == "1 week":
                    days_options = ["MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN"]
                    day_of_week = st.selectbox("Select day of the week:", days_options)

                elif selected_freq == "1 month":
                    day_of_month = st.selectbox("Select day of the month:", list(range(1, 32)))

                col1, col2 = st.columns(2)
                with col1:
                    hour = st.selectbox("Hour:", list(range(0, 24)))
                with col2:
                    minute = st.selectbox("Minute:", list(range(0, 60)))

            else:
                hour, minute = 0, 0

            # Generate the cron expression
            frequency = generate_cron_expression(selected_freq, hour, minute, day_of_week, day_of_month)

            # Assumed DBU consumption rate for a 32GB, 4-core node per hour
            dbu_rate_per_hour = 1  # Replace this with the actual rate from Databricks' pricing or documentation

            # Calculate DBU consumption for a single run: 1/6 to 2/3 of an
            # hour on 1 + 10 nodes (assuming maximum scaling to 10 workers).
            min_dbu_single_run = (dbu_rate_per_hour / 6) * (1 + 10)
            max_dbu_single_run = (2 * dbu_rate_per_hour / 3) * (1 + 10)

            # Calculate runs per month (anything not daily/weekly counts as monthly)
            runs_per_month = {"1 day": 30, "1 week": 4}.get(selected_freq, 1)

            # Calculate estimated DBU consumption per month
            min_dbu_monthly = runs_per_month * min_dbu_single_run
            max_dbu_monthly = runs_per_month * max_dbu_single_run

            # Generate the human-readable message
            readable_msg = generate_human_readable_message(selected_freq, hour, minute, day_of_week, day_of_month)

            # Main code block
            if frequency:
                st.success(readable_msg, icon="📅")
                # Display the estimated DBU consumption to the user
                st.warning(f"Estimated DBU Consumption:\n- For a single run: {min_dbu_single_run:.2f} to {max_dbu_single_run:.2f} DBUs\n- Monthly (based on {runs_per_month} runs): {min_dbu_monthly:.2f} to {max_dbu_monthly:.2f} DBUs")

                # Disclaimer
                st.info("Disclaimer: This is only an estimation. Always monitor the job in Databricks to assess actual DBU consumption.")

                st.subheader("Select a table", divider="grey")

                # Fetching databases
                result_data = fetch_databases(cluster_id, formatted_title, databricks_api_key)
                database_names = result_data['databaseName'].tolist()
                selected_database = st.selectbox("Select a Database:", database_names, index=None, placeholder="Select a database..")

                if selected_database:
                    # Fetching tables
                    result_data = fetch_tables(selected_database, cluster_id, formatted_title, databricks_api_key)
                    table_names = result_data['tableName'].tolist()
                    selected_table = st.selectbox("Select a Table:", table_names, index=None, placeholder="Select a table..")

                    if selected_table:
                        # Fetching columns
                        result_data = fetch_columns(selected_database, selected_table, cluster_id, formatted_title, databricks_api_key)
                        column_names = result_data['col_name'].tolist()

                        st.subheader("Map table schema to Labelbox schema", divider="grey")

                        # Fetch the first 5 rows of the selected table
                        with st.spinner('Fetching first 5 rows of the selected table...'):
                            query = f"SELECT * FROM {selected_database}.{selected_table} LIMIT 5;"
                            table_sample_data = execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)
                            st.write(table_sample_data)

                        # Define two columns for side-by-side selectboxes
                        col1, col2 = st.columns(2)

                        with col1:
                            selected_row_data = st.selectbox(
                                "row_data (required):",
                                column_names,
                                index=None,
                                placeholder="Select a column..",
                                help="Select the column that contains the URL/URI bucket location of the data rows you wish to import into Labelbox."
                            )

                        with col2:
                            selected_global_key = st.selectbox(
                                "global_key (optional):",
                                column_names,
                                index=None,
                                placeholder="Select a column..",
                                help="Select the column that contains the global key. If not provided, a new key will be generated for you."
                            )

                        if selected_row_data:
                            # Validate a sample value from the chosen row_data
                            # column. The preview fetched above is reused, so
                            # no extra query is run and an empty table cannot
                            # raise on the first-row lookup.
                            if selected_row_data in table_sample_data.columns and not table_sample_data.empty:
                                sample_row_data_value = table_sample_data[selected_row_data].iloc[0]
                                if not is_valid_url_or_uri(str(sample_row_data_value)):
                                    st.warning(f"The first value in column '{selected_row_data}' does not look like a URL/URI: {sample_row_data_value}")

                            # Mode
                            mode = "preview" if is_preview else "production"

                            # Table Path
                            table_path = f"{selected_database}.{selected_table}"

                            # Schema Map, keyed by table column with the
                            # Labelbox field as value.
                            # NOTE(review): choosing the same column for both
                            # fields leaves a single entry mapped to
                            # 'global_key' - confirm whether that is acceptable.
                            reversed_schema_map_dict = {selected_row_data: 'row_data'}
                            if selected_global_key:
                                reversed_schema_map_dict[selected_global_key] = 'global_key'

                            # Convert the reversed dictionary to a stringified JSON
                            reversed_schema_map_str = json.dumps(reversed_schema_map_dict)

                            if st.button("Deploy Pipeline!", type="primary"):
                                deploy_error = None
                                response = None

                                with st.spinner("Deploying pipeline..."):
                                    # Create the dataset only on an actual
                                    # deploy click; doing it earlier would
                                    # create one on every script rerun.
                                    deploy_dataset_id = (
                                        create_new_dataset_labelbox(new_dataset_name)
                                        if create_new_dataset
                                        else dataset_id
                                    )

                                    data = {
                                        "cloud": cloud,
                                        "mode": mode,
                                        "databricks_instance": formatted_title,
                                        "databricks_api_key": databricks_api_key,
                                        "new_dataset": 1 if create_new_dataset else 0,
                                        "dataset_id": deploy_dataset_id,
                                        "table_path": table_path,
                                        "labelbox_api_key": labelbox_api_key,
                                        "frequency": frequency,
                                        "new_cluster": 0,
                                        "cluster_id": cluster_id,
                                        "schema_map": reversed_schema_map_str
                                    }

                                    # Sending a POST request to the deployment endpoint
                                    try:
                                        response = requests.post(DEPLOY_PIPELINE_URL, json=data)
                                    except requests.RequestException as e:
                                        deploy_error = str(e)

                                # Check if request was successful
                                if response is not None and response.status_code == 200:
                                    st.balloons()
                                    job_id = _extract_job_id(response)
                                    if job_id:
                                        # Generate the Databricks Job URL
                                        job_url = _build_job_url(formatted_title, job_id)
                                        st.success(f"Pipeline deployed successfully! [{job_url}]({job_url}) 🚀")
                                    else:
                                        # No job id could be read from the
                                        # response; report success without a link.
                                        st.success("Pipeline deployed successfully! 🚀")
                                else:
                                    failure_detail = deploy_error if response is None else response.text
                                    st.error(f"Failed to deploy pipeline. Response: {failure_detail}", icon="🚫")

# Inject a CSS rule that adds 1000px of bottom padding to the main block
# container; unsafe_allow_html is passed so the <style> tag is not escaped.
st.markdown("""
<style>
/* Add a large bottom padding to the main content */
.main .block-container {
    padding-bottom: 1000px;  /* Adjust this value as needed */
}
</style>
""", unsafe_allow_html=True)