import json
import os
import tempfile

import ee
import geemap
import geojson
import geopandas as gpd
import pandas as pd
import streamlit as st
from fastkml import kml

# Earth Engine credentials may be supplied through the "EE" environment variable
# (e.g. a deployment secret); when present, write them to the default Earth
# Engine credentials path before initializing the client.
ee_credentials = os.environ.get("EE")
if ee_credentials:
    credentials_dir = os.path.expanduser("~/.config/earthengine/")
    os.makedirs(credentials_dir, exist_ok=True)
    with open(os.path.join(credentials_dir, "credentials"), "w", encoding="utf-8") as f:
        f.write(ee_credentials)
# When "EE" is unset, rely on credentials already on disk instead of truncating
# that file and then crashing on f.write(None).

ee.Initialize()

def _to_xy(points):
    """Return *points* as a list of (x, y) tuples, dropping any further ordinates."""
    return [(x, y) for x, y, *_ in points]


def convert_3d_to_2d(geometry):
    """
    Convert any 3D coordinates in a geometry to 2D geojson geometries.

    Point, MultiPoint, LineString, MultiLineString, Polygon and MultiPolygon are
    rebuilt as the matching ``geojson`` geometry with only x/y kept.

    Args:
        geometry: an object exposing ``is_empty``, ``geom_type`` and
            ``coordinates``, or ``None``.

    Returns:
        ``None`` when *geometry* is ``None`` (e.g. a Placemark without a
        geometry); *geometry* unchanged when it is empty or of an unsupported
        type; otherwise a 2D ``geojson`` geometry.
    """
    if geometry is None:
        return None
    if geometry.is_empty:
        return geometry

    geom_type = geometry.geom_type
    # Access .coordinates only inside the supported branches: unsupported
    # types (e.g. GeometryCollection) may not provide it.
    if geom_type == 'Point':
        x, y, *_ = geometry.coordinates
        return geojson.Point((x, y))
    if geom_type == 'MultiPoint':
        return geojson.MultiPoint(_to_xy(geometry.coordinates))
    if geom_type == 'LineString':
        return geojson.LineString(_to_xy(geometry.coordinates))
    if geom_type == 'MultiLineString':
        return geojson.MultiLineString([_to_xy(line) for line in geometry.coordinates])
    if geom_type == 'Polygon':
        return geojson.Polygon([_to_xy(ring) for ring in geometry.coordinates])
    if geom_type == 'MultiPolygon':
        return geojson.MultiPolygon(
            [[_to_xy(ring) for ring in poly] for poly in geometry.coordinates]
        )

    return geometry  # Return unchanged if not a supported geometry type

def _iter_placemark_geometries(feature):
    """Yield the non-None geometry of every feature at or beneath *feature*."""
    children = getattr(feature, "features", None)
    if children is not None:
        # Containers (the KML root, Document, Folder) hold child features rather
        # than a geometry; recurse so Placemarks nested at any depth are found.
        # Accept both a features() method and a plain features list.
        for child in (children() if callable(children) else children):
            yield from _iter_placemark_geometries(child)
        return
    # getattr default also covers geometry properties that raise AttributeError
    # when the feature has no geometry.
    geometry = getattr(feature, "geometry", None)
    if geometry is not None:
        yield geometry


def kml_to_geojson(kml_string):
    """
    Parse a KML document and return its geometries as a GeoJSON FeatureCollection.

    Every Placemark geometry found anywhere in the document (including inside
    Document/Folder containers) becomes one property-less Feature, with Z values
    dropped via convert_3d_to_2d. Features without a geometry are skipped.

    Args:
        kml_string: the KML document text.

    Returns:
        A ``geojson.FeatureCollection``.
    """
    k = kml.KML()
    # Pass bytes: some XML parsers (e.g. lxml) reject str input that carries an
    # XML encoding declaration.
    k.from_string(kml_string.encode('utf-8'))

    geojson_features = [
        geojson.Feature(geometry=convert_3d_to_2d(geometry))
        for geometry in _iter_placemark_geometries(k)
    ]
    return geojson.FeatureCollection(geojson_features)

def geojson_to_ee(geojson_data):
    """Return the Earth Engine object geemap builds from *geojson_data*."""
    return geemap.geojson_to_ee(geojson_data)

# Inject a small CSS rule so the page's <h1> title is horizontally centered.
_TITLE_CSS = """
<style>
h1 {
    text-align: center;
}
</style>
"""
st.markdown(_TITLE_CSS, unsafe_allow_html=True)

st.title("Mean NDVI Calculator")

# Date range and cloud-cover threshold used to filter the Sentinel-2 collection.
col = st.columns(2)
start_date = col[0].date_input("Start Date", value=pd.to_datetime('2021-01-01'))
end_date = col[1].date_input("End Date", value=pd.to_datetime('2021-01-30'))
if start_date > end_date:
    st.error("Start Date must be on or before End Date.")
    st.stop()
start_date = start_date.strftime("%Y-%m-%d")
# Earth Engine's date filter treats the end bound as exclusive; advance it one
# day so images acquired on the selected End Date are included.
end_date = (pd.Timestamp(end_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")

# Compared against CLOUDY_PIXEL_PERCENTAGE, so restrict it to 0-100.
max_cloud_cover = st.number_input("Max Cloud Cover", min_value=0, max_value=100, value=20)

# Get the boundary file from the user, falling back to a bundled default.
uploaded_file = st.file_uploader("Upload KML/GeoJSON file", type=["geojson", "kml"])

if uploaded_file is None:
    file_name = "Bhankhara_Df_11_he_5_2020-21.geojson"
    st.write(f"Using default file: {file_name}")
    # "utf-8-sig" drops a leading byte-order mark, which json.loads rejects.
    with open(file_name, encoding="utf-8-sig") as f:
        str_data = f.read()
else:
    st.write(f"Using uploaded file: {uploaded_file.name}")
    file_name = uploaded_file.name
    try:
        str_data = uploaded_file.getvalue().decode("utf-8-sig")
    except UnicodeDecodeError:
        st.error(f"{file_name} is not valid UTF-8 text.")
        st.stop()

# Parse into GeoJSON. Compare extensions case-insensitively so e.g. "AREA.KML"
# is handled, and stop with a message instead of leaving geojson_data undefined.
extension = os.path.splitext(file_name)[1].lower()
if extension == ".geojson":
    try:
        geojson_data = json.loads(str_data)
    except json.JSONDecodeError as err:
        st.error(f"Could not parse {file_name} as GeoJSON: {err}")
        st.stop()
elif extension == ".kml":
    geojson_data = kml_to_geojson(str_data)
else:
    st.error(f"Unsupported file type: {file_name}")
    st.stop()

# Convert the boundary to an Earth Engine object.
ee_object = geojson_to_ee(geojson_data)

# Sentinel-2 images over the boundary, within the date range and cloud
# threshold, keeping only the Red (B4) and NIR (B8) bands.
collection = (
    ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
    .filterBounds(ee_object)
    .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', max_cloud_cover))
    .filter(ee.Filter.date(start_date, end_date))
    .select(['B4', 'B8'])
)

# getInfo() is a round trip to the Earth Engine server; do it once.
image_count = collection.size().getInfo()
st.write(f"Number of images: {image_count}")
if image_count == 0:
    st.warning("No images match the selected date range and cloud cover; nothing to compute.")
    st.stop()

# NDVI as a normalized difference of NIR (B8) and Red (B4).
def calculate_ndvi(image):
    """Return *image* with an added 'NDVI' band: (B8 - B4) / (B8 + B4)."""
    nir_red = ['B8', 'B4']
    ndvi_band = image.normalizedDifference(nir_red).rename('NDVI')
    return image.addBands(ndvi_band)

collection = collection.map(calculate_ndvi)

# Per-image mean NDVI over the boundary. geemap.zonal_stats writes a CSV; put it
# in a private temporary directory rather than a shared "tmp.csv" in the working
# directory, which concurrent app sessions would overwrite.
with tempfile.TemporaryDirectory() as tmp_dir:
    stats_path = os.path.join(tmp_dir, "ndvi_stats.csv")
    geemap.zonal_stats(collection.select(["NDVI"]), ee_object, stats_path, stat_type="mean", scale=10)
    raw_stats = pd.read_csv(stats_path)

# NDVI columns are expected to be named "<image id>_NDVI" (the collection is
# flattened to one band per image, prefixed with the image id). Select them by
# name instead of dropping a fixed number of trailing columns, since the extra
# columns depend on the properties carried by the boundary file.
# NOTE(review): a single-image collection may produce a column named after the
# reducer (e.g. "mean") instead — confirm against geemap output.
ndvi_columns = [name for name in raw_stats.columns if str(name).endswith("_NDVI")]
if not ndvi_columns:
    st.warning("No NDVI statistics were returned for the selected inputs.")
    st.stop()

# One row per image: the band name becomes a Date, values stay per feature.
df = raw_stats[ndvi_columns].T.reset_index()
# Image ids are assumed to be "<YYYYMMDDTHHMMSS>_<YYYYMMDDTHHMMSS>_<tile>";
# the date is taken from the second timestamp.
df['index'] = pd.to_datetime(
    df['index'].map(lambda name: name.split('_')[1].split('T')[0]),
    format="%Y%m%d",
).dt.strftime('%Y-%m-%d')
df.rename(columns={'index': 'Date', 0: 'Mean NDVI'}, inplace=True)
# Chronological order for the table and the time-series chart
# (ISO date strings sort chronologically).
df = df.sort_values('Date', kind='stable').reset_index(drop=True)
st.write(df)

# Chart the NDVI time series (indexed by Date) and report its overall average.
st.write("Time Series Plot")
chart_data = df.set_index('Date')
st.line_chart(chart_data)

overall_mean = df['Mean NDVI'].mean()
st.write(f"Overall Mean NDVI: {overall_mean:.2f}")