# SATRANG / app.py
# Hugging Face Spaces page residue kept for provenance:
# "YashMK89's picture", "update app.py", commit d2c85d1 (verified),
# "raw", "history blame", file size 30.8 kB.
import streamlit as st
import json
import ee
import os
import pandas as pd
import geopandas as gpd
from datetime import datetime
import leafmap.foliumap as leafmap
import re
from shapely.geometry import base
from lxml import etree
from xml.etree import ElementTree as ET
# Page chrome: wide layout, button styling, logos and headings.
st.set_page_config(layout="wide")

# CSS that paints the first button of every Streamlit button container
# dark green with white text.
_BUTTON_STYLE = """
<style>
div.stButton > button:first-child {
background-color: #006400;
color:#ffffff;
}
</style>"""

# Two logos pushed to the left and right edges of a flex row.
_LOGO_ROW = """
<div style="display: flex; justify-content: space-between; align-items: center;">
<img src="https://huggingface.co/spaces/YashMK89/GEE_Calculator/resolve/main/ISRO_Logo.png" style="width: 20%; margin-right: auto;">
<img src="https://huggingface.co/spaces/YashMK89/GEE_Calculator/resolve/main/SAC_Logo.png" style="width: 20%; margin-left: auto;">
</div>
"""

# Centered application title.
_PAGE_TITLE = """
<h1 style="text-align: center;">Precision Analysis for Vegetation, Water, and Air Quality</h1>
"""

# `m` is rebound to a leafmap map further down once a file is uploaded.
m = st.markdown(_BUTTON_STYLE, unsafe_allow_html=True)
st.write(_LOGO_ROW, unsafe_allow_html=True)
st.markdown(_PAGE_TITLE, unsafe_allow_html=True)
st.write("<h2><div style='text-align: center;'>User Inputs</div></h2>", unsafe_allow_html=True)
# Authenticate and initialize Earth Engine.
# The credentials are read from the EE_Authentication environment variable and
# written to the credentials file under ~/.config/earthengine/ before
# ee.Initialize() is called.
earthengine_credentials = os.environ.get("EE_Authentication")
if not earthengine_credentials:
    # Without this guard f.write(None) raises TypeError after the credentials
    # file has already been truncated.
    st.error("Earth Engine credentials are not configured: the EE_Authentication environment variable is not set.")
    st.stop()
os.makedirs(os.path.expanduser("~/.config/earthengine/"), exist_ok=True)
with open(os.path.expanduser("~/.config/earthengine/credentials"), "w") as f:
    f.write(earthengine_credentials)
ee.Initialize(project='ee-yashsacisro24')
# Load the Sentinel dataset options from JSON file.
# The code below reads data[category]["sub_options"] (dataset id -> label)
# and data[category]["bands"] (dataset id -> list of band names).
with open("sentinel_datasets.json") as f:
    data = json.load(f)

# Display the title for the Streamlit app
st.title("Sentinel Dataset")

# Select dataset category (main selection)
main_selection = st.selectbox("Select Sentinel Dataset Category", list(data.keys()))
if not main_selection:
    # Nothing to select: stop here so later code never touches names
    # (sub_selection, dataset_id, ...) that would otherwise be unbound.
    st.warning("No dataset category is available to select.")
    st.stop()

# Select the specific dataset inside the chosen category
sub_options = data[main_selection]["sub_options"]
sub_selection = st.selectbox("Select Specific Dataset ID", list(sub_options.keys()))
if not sub_selection:
    st.warning("No dataset is available in the selected category.")
    st.stop()

# Display the selected dataset ID based on user input
st.write(f"You selected: {main_selection} -> {sub_options[sub_selection]}")
st.write(f"Dataset ID: {sub_selection}")
dataset_id = sub_selection  # Use the key directly as the dataset ID

# Earth Engine Index Calculator Section
st.header("Earth Engine Index Calculator")

# Load band information based on selected dataset
dataset_bands = data[main_selection]["bands"].get(sub_selection, [])
st.write(f"Available Bands for {sub_options[sub_selection]}: {', '.join(dataset_bands)}")

# Allow user to select 1 or 2 bands
selected_bands = st.multiselect(
    "Select 1 or 2 Bands for Calculation",
    options=dataset_bands,
    default=[dataset_bands[0]] if dataset_bands else [],
    help="Select at least 1 band and up to 2 bands."
)

# Ensure minimum 1 and maximum 2 bands are selected
if len(selected_bands) < 1:
    st.warning("Please select at least one band.")
    st.stop()
elif len(selected_bands) > 2:
    st.warning("You can select a maximum of 2 bands.")
    st.stop()

# Pre-fill the formula: the band itself for one band, a normalized
# difference (a - b) / (a + b) for two bands.
default_formula = (
    f"{selected_bands[0]}" if len(selected_bands) == 1
    else f"({selected_bands[0]} - {selected_bands[1]}) / ({selected_bands[0]} + {selected_bands[1]})"
)
custom_formula = st.text_input(
    "Enter Custom Formula (e.g., 'B3*B5/2' or '(B8 - B4) / (B8 + B4)')",
    value=default_formula,
    help=f"Use {', '.join(selected_bands)} in your formula. Example: 'B3*B5/2'"
)
if not custom_formula:
    st.warning("Please enter a custom formula to proceed.")
    st.stop()

# Display the formula
st.write(f"Custom Formula: {custom_formula}")
# Function to get the corresponding reducer based on user input
def get_reducer(reducer_name):
    """
    Return the Earth Engine reducer matching a user-facing reducer name.

    The name is matched case-insensitively against 'mean', 'sum', 'median',
    'min', 'max' and 'count'; any other name falls back to the mean reducer.
    """
    supported_names = ('mean', 'sum', 'median', 'min', 'max', 'count')
    key = reducer_name.lower()
    if key not in supported_names:
        key = 'mean'
    # Each supported name is also the name of the ee.Reducer factory method.
    return getattr(ee.Reducer, key)()
# Streamlit selectbox for reducer choice; the first entry ('mean') is the default.
_REDUCER_OPTIONS = ['mean', 'sum', 'median', 'min', 'max', 'count']
reducer_choice = st.selectbox("Select Reducer", _REDUCER_OPTIONS, index=0)
# Function to convert geometry to Earth Engine format
def convert_to_ee_geometry(geometry):
    """
    Convert a geometry description into an ``ee.Geometry``.

    Accepted inputs:
      * a string ending in ".kml" (case-insensitive): treated as a path to a
        KML file; the first ``<coordinates>`` element becomes a Polygon ring;
      * a GeoJSON geometry as a dict, or as a JSON string;
      * a Shapely geometry (must be valid).

    Raises:
        ValueError: for an invalid Shapely geometry, unparsable or incomplete
            GeoJSON / KML, or any other input type.
    """
    # KML path. This must be tested before the generic string case below,
    # otherwise every string is handed to json.loads and this branch can
    # never run.
    if isinstance(geometry, str) and geometry.lower().endswith(".kml"):
        try:
            kml_root = ET.parse(geometry).getroot()
            kml_namespace = {'kml': 'http://www.opengis.net/kml/2.2'}
            coordinates = kml_root.findall(".//kml:coordinates", kml_namespace)
            coords_text = (coordinates[0].text or "").strip() if coordinates else ""
            if not coords_text:
                raise ValueError("KML does not contain valid coordinates.")
            # KML tuples are "lon,lat[,alt]"; keep only lon/lat.
            # NOTE(review): ee.Geometry is assumed to expect 2-D positions here.
            coords = [
                tuple(float(part) for part in coord.split(',')[:2])
                for coord in coords_text.split()
            ]
            geojson = {"type": "Polygon", "coordinates": [coords]}
            return ee.Geometry(geojson)
        except Exception as e:
            raise ValueError(f"Error parsing KML: {e}") from e

    # GeoJSON, either already decoded or as a JSON string.
    if isinstance(geometry, (dict, str)):
        try:
            if isinstance(geometry, str):
                geometry = json.loads(geometry)
            if 'type' in geometry and 'coordinates' in geometry:
                return ee.Geometry(geometry)
            raise ValueError("GeoJSON format is invalid.")
        except Exception as e:
            raise ValueError(f"Error parsing GeoJSON: {e}") from e

    # Shapely geometry: go through its GeoJSON-like interface.
    if isinstance(geometry, base.BaseGeometry):
        if not geometry.is_valid:
            raise ValueError("Invalid geometry: The polygon geometry is not valid.")
        return ee.Geometry(geometry.__geo_interface__)

    raise ValueError("Unsupported geometry input type. Supported types are Shapely, GeoJSON, and KML.")
# Date Input for Start and End Dates
start_date = st.date_input("Start Date", value=pd.to_datetime('2024-11-01'))
end_date = st.date_input("End Date", value=pd.to_datetime('2024-12-01'))

# Reject an inverted range up front instead of passing it on to the
# Earth Engine date filter.
if start_date > end_date:
    st.error("Start Date must not be after End Date.")
    st.stop()

# Convert start_date and end_date to string format for Earth Engine
start_date_str = start_date.strftime('%Y-%m-%d')
end_date_str = end_date.strftime('%Y-%m-%d')

# Aggregation period selection
aggregation_period = st.selectbox("Select Aggregation Period", ["Daily", "Weekly", "Monthly", "Yearly"], index=0)

# Ask user whether they want to process 'Point' or 'Polygon' data
shape_type = st.selectbox("Do you want to process 'Point' or 'Polygon' data?", ["Point", "Polygon"])

# Additional options based on shape type. Only the option that belongs to the
# chosen shape type is asked for; the other one stays None.
kernel_size = None
include_boundary = None
if shape_type.lower() == "point":
    kernel_size = st.selectbox(
        "Select Calculation Area",
        ["Point", "3x3 Kernel", "5x5 Kernel"],
        index=0,
        help="Choose 'Point' for exact point calculation, or a kernel size for area averaging."
    )
elif shape_type.lower() == "polygon":
    include_boundary = st.checkbox(
        "Include Boundary Pixels",
        value=True,
        help="Check to include pixels on the polygon boundary; uncheck to exclude them."
    )
def _read_uploaded_locations(uploaded_file):
    """Read an uploaded CSV, GeoJSON or KML file into a (Geo)DataFrame.

    The extension is matched case-insensitively, so "POINTS.CSV" is handled
    like "points.csv". For any other extension an error is shown and an empty
    DataFrame is returned.
    """
    file_name = uploaded_file.name.lower()
    if file_name.endswith('.csv'):
        return pd.read_csv(uploaded_file)
    if file_name.endswith(('.geojson', '.kml')):
        return gpd.read_file(uploaded_file)
    st.error("Unsupported file format. Please upload CSV, GeoJSON, or KML.")
    return pd.DataFrame()


# Ask user to upload a file based on shape type
file_upload = st.file_uploader(f"Upload your {shape_type} data (CSV, GeoJSON, KML)", type=["csv", "geojson", "kml"])
if file_upload is not None:
    # Read the user-uploaded file
    if shape_type.lower() == "point":
        locations_df = _read_uploaded_locations(file_upload)
        # Polygon data cannot be processed in point mode.
        if 'geometry' in locations_df.columns:
            if locations_df.geometry.geom_type.isin(['Polygon', 'MultiPolygon']).any():
                st.warning("The uploaded file contains polygon data. Please select 'Polygon' for processing.")
                st.stop()
        with st.spinner('Processing Map...'):
            if locations_df is not None and not locations_df.empty:
                # Files with a geometry column carry the coordinates there.
                if 'geometry' in locations_df.columns:
                    locations_df['latitude'] = locations_df['geometry'].y
                    locations_df['longitude'] = locations_df['geometry'].x
                if 'latitude' not in locations_df.columns or 'longitude' not in locations_df.columns:
                    st.error("Uploaded file is missing required 'latitude' or 'longitude' columns.")
                else:
                    st.write("Preview of the uploaded points data:")
                    st.dataframe(locations_df.head())
                    # Center the map on the mean coordinate of all points.
                    m = leafmap.Map(center=[locations_df['latitude'].mean(), locations_df['longitude'].mean()], zoom=10)
                    for _, row in locations_df.iterrows():
                        latitude = row['latitude']
                        longitude = row['longitude']
                        if pd.isna(latitude) or pd.isna(longitude):
                            continue
                        m.add_marker(location=[latitude, longitude], popup=row.get('name', 'No Name'))
                    st.write("Map of Uploaded Points:")
                    m.to_streamlit()
                    st.session_state.map_data = m
    elif shape_type.lower() == "polygon":
        locations_df = _read_uploaded_locations(file_upload)
        # Point data cannot be processed in polygon mode.
        if 'geometry' in locations_df.columns:
            if locations_df.geometry.geom_type.isin(['Point', 'MultiPoint']).any():
                st.warning("The uploaded file contains point data. Please select 'Point' for processing.")
                st.stop()
        with st.spinner('Processing Map...'):
            if locations_df is not None and not locations_df.empty:
                if 'geometry' not in locations_df.columns:
                    st.error("Uploaded file is missing required 'geometry' column.")
                else:
                    st.write("Preview of the uploaded polygons data:")
                    st.dataframe(locations_df.head())
                    # Center the map on the mean of the polygon centroids.
                    centroid_lat = locations_df.geometry.centroid.y.mean()
                    centroid_lon = locations_df.geometry.centroid.x.mean()
                    m = leafmap.Map(center=[centroid_lat, centroid_lon], zoom=10)
                    for _, row in locations_df.iterrows():
                        polygon = row['geometry']
                        if polygon.is_valid:
                            # Each valid polygon is added as its own map layer.
                            gdf = gpd.GeoDataFrame([row], geometry=[polygon], crs=locations_df.crs)
                            m.add_gdf(gdf=gdf, layer_name=row.get('name', 'Unnamed Polygon'))
                    st.write("Map of Uploaded Polygons:")
                    m.to_streamlit()
                    st.session_state.map_data = m
# Initialize session state for storing results
for _state_key, _state_default in (('results', []), ('last_params', {}), ('map_data', None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default


def _current_params():
    """Snapshot of the user inputs that are tracked for change detection."""
    return {
        'main_selection': main_selection,
        'dataset_id': dataset_id,
        'selected_bands': selected_bands,
        'custom_formula': custom_formula,
        'start_date_str': start_date_str,
        'end_date_str': end_date_str,
        'shape_type': shape_type,
        'file_upload': file_upload,
        'kernel_size': kernel_size,
        'include_boundary': include_boundary
    }


# Function to check if parameters have changed
def parameters_changed():
    """Return True when any tracked input differs from the stored snapshot."""
    previous = st.session_state.last_params
    return any(
        previous.get(param_name) != param_value
        for param_name, param_value in _current_params().items()
    )


# If parameters have changed, reset the results and remember the new inputs
if parameters_changed():
    st.session_state.results = []
    st.session_state.last_params = _current_params()
# Function to calculate custom formula using eval
def calculate_custom_formula(image, geometry, selected_bands, custom_formula, reducer_choice, scale=30):
    """
    Evaluate a user formula on per-band reduced values of *image*.

    Each selected band is reduced over *geometry* with the chosen reducer at
    the given *scale*. The resulting numbers are substituted into
    *custom_formula*, the formula is evaluated, and the numeric result is
    returned as a constant image with a single band named 'custom_result'.

    A band whose reduced value is None is treated as 0.

    On any failure (missing band, division by zero, syntax error, non-numeric
    result, other exception) an error is shown in the Streamlit UI and a
    zero-valued 'custom_result' image carrying an 'error' property is
    returned instead of raising.
    """
    try:
        # Fetch the band names once; this is a server round trip and does not
        # depend on which selected band is being checked.
        available_bands = image.bandNames().getInfo()
        for band in selected_bands:
            if band not in available_bands:
                raise ValueError(f"The band '{band}' does not exist in the image.")

        reducer = get_reducer(reducer_choice)
        reduced_values = {}
        for band in selected_bands:
            reduced_value = image.select(band).reduceRegion(
                reducer=reducer,
                geometry=geometry,
                scale=scale
            ).get(band).getInfo()
            if reduced_value is None:
                reduced_value = 0
            reduced_values[band] = float(reduced_value)

        formula = custom_formula
        if reduced_values:
            # Substitute all bands in one pass, matching whole names only and
            # trying longer names first, so that "B1" never rewrites part of
            # "B11". Values are parenthesized so a negative number keeps its
            # sign under operators such as "**".
            names_longest_first = sorted(reduced_values, key=len, reverse=True)
            band_pattern = re.compile(
                r'(?<![A-Za-z0-9_])(?:'
                + '|'.join(re.escape(name) for name in names_longest_first)
                + r')(?![A-Za-z0-9_])'
            )
            formula = band_pattern.sub(
                lambda match: f"({reduced_values[match.group(0)]})",
                formula
            )

        # SECURITY NOTE: the formula is user-supplied text passed to eval().
        # Emptying __builtins__ narrows what it can reach but is not a real
        # sandbox; consider a dedicated expression parser.
        result = eval(formula, {"__builtins__": {}}, reduced_values)
        if not isinstance(result, (int, float)):
            raise ValueError("Formula evaluation did not result in a numeric value.")
        return ee.Image.constant(result).rename('custom_result')
    except ZeroDivisionError:
        st.error("Error: Division by zero occurred in the formula.")
        return ee.Image(0).rename('custom_result').set('error', 'Division by zero')
    except SyntaxError:
        st.error(f"Error: Invalid formula syntax in '{custom_formula}'.")
        return ee.Image(0).rename('custom_result').set('error', 'Invalid syntax')
    except ValueError as e:
        st.error(f"Error: {str(e)}")
        return ee.Image(0).rename('custom_result').set('error', str(e))
    except Exception as e:
        st.error(f"Unexpected error evaluating formula: {e}")
        return ee.Image(0).rename('custom_result').set('error', str(e))
# Function to calculate index for a period
def calculate_index_for_period(image, roi, selected_bands, custom_formula, reducer_choice):
    """Evaluate the custom formula for one period image over *roi*.

    Delegates to calculate_custom_formula, leaving its scale at the default.
    """
    return calculate_custom_formula(
        image=image,
        geometry=roi,
        selected_bands=selected_bands,
        custom_formula=custom_formula,
        reducer_choice=reducer_choice,
    )
# Aggregation functions
def aggregate_data_daily(collection):
    """Collapse *collection* into one mean image per calendar day.

    Every image is tagged with a 'day' property ('YYYY-MM-dd', derived from
    'system:time_start'); the images of each distinct day are averaged and the
    mean image carries the same 'day' property.
    """
    def tag_day(img):
        acquired = ee.Date(img.get('system:time_start'))
        return img.set('day', acquired.format('YYYY-MM-dd'))

    tagged = collection.map(tag_day)
    distinct_days = tagged.aggregate_array('day').distinct()

    def mean_of_day(day):
        same_day = tagged.filter(ee.Filter.eq('day', day))
        return same_day.mean().set('day', day)

    return ee.ImageCollection(ee.List(distinct_days.map(mean_of_day)))
def aggregate_data_weekly(collection):
    """Collapse *collection* into one mean image per week.

    Every image is tagged with a 'week_start' property: its acquisition date
    moved back by its day offset within the week, formatted 'YYYY-MM-dd'.
    The images of each distinct week are averaged and the mean image carries
    the same 'week_start' property.
    """
    def tag_week(img):
        acquired = ee.Date(img.get('system:time_start'))
        days_into_week = ee.Number(acquired.getRelative('day', 'week'))
        first_day = acquired.advance(days_into_week.multiply(-1), 'day')
        return img.set('week_start', first_day.format('YYYY-MM-dd'))

    tagged = collection.map(tag_week)
    distinct_weeks = tagged.aggregate_array('week_start').distinct()

    def mean_of_week(week_start):
        same_week = tagged.filter(ee.Filter.eq('week_start', week_start))
        return same_week.mean().set('week_start', week_start)

    return ee.ImageCollection(ee.List(distinct_weeks.map(mean_of_week)))
def aggregate_data_monthly(collection, start_date, end_date):
    """Collapse *collection* into one mean image per calendar month.

    The collection is first restricted to [start_date, end_date) via
    filterDate. Every remaining image is tagged with a 'month' property
    ('YYYY-MM'); the images of each distinct month are averaged and the mean
    image carries the same 'month' property.
    """
    def tag_month(img):
        acquired = ee.Date(img.get('system:time_start'))
        return img.set('month', acquired.format('YYYY-MM'))

    tagged = collection.filterDate(start_date, end_date).map(tag_month)
    distinct_months = tagged.aggregate_array('month').distinct()

    def mean_of_month(month):
        same_month = tagged.filter(ee.Filter.eq('month', month))
        return same_month.mean().set('month', month)

    return ee.ImageCollection(ee.List(distinct_months.map(mean_of_month)))
def aggregate_data_yearly(collection):
    """Collapse *collection* into one mean image per calendar year.

    Every image is tagged with a 'year' property ('YYYY'); the images of each
    distinct year are averaged and the mean image carries the same 'year'
    property.
    """
    def tag_year(img):
        acquired = ee.Date(img.get('system:time_start'))
        return img.set('year', acquired.format('YYYY'))

    tagged = collection.map(tag_year)
    distinct_years = tagged.aggregate_array('year').distinct()

    def mean_of_year(year):
        same_year = tagged.filter(ee.Filter.eq('year', year))
        return same_year.mean().set('year', year)

    return ee.ImageCollection(ee.List(distinct_years.map(mean_of_year)))
# Process aggregation function with kernel and boundary options

# Per aggregation period: the image property holding the period key, the
# result column label, and the ee.Date format used to render the key
# (None means the stored string is used as-is).
_PERIOD_SETTINGS = {
    'daily': ('day', 'Date', 'YYYY-MM-dd'),
    'weekly': ('week_start', 'Week', None),
    'monthly': ('month', 'Month', 'YYYY-MM'),
    'yearly': ('year', 'Year', 'YYYY'),
}


def _point_target(idx, row, kernel_size):
    """Return (roi, name, extra result fields) for a point row, or None to skip it."""
    latitude = row.get('latitude')
    longitude = row.get('longitude')
    if pd.isna(latitude) or pd.isna(longitude):
        st.warning(f"Skipping location {idx} with missing latitude or longitude")
        return None
    location_name = row.get('name', f"Location_{idx}")
    point = ee.Geometry.Point([longitude, latitude])
    # Kernel sizes assume 30 m pixels: the point is buffered by half the
    # kernel width and the buffer's bounding box is used as the region.
    if kernel_size == "3x3 Kernel":
        roi = point.buffer(45).bounds()   # 3x3 kernel = 90m x 90m
    elif kernel_size == "5x5 Kernel":
        roi = point.buffer(75).bounds()   # 5x5 kernel = 150m x 150m
    else:  # Point
        roi = point
    return roi, location_name, {'Latitude': latitude, 'Longitude': longitude}


def _polygon_target(idx, row, include_boundary):
    """Return (roi, name, extra result fields) for a polygon row, or None to skip it."""
    polygon_name = row.get('name', f"Polygon_{idx}")
    try:
        roi = convert_to_ee_geometry(row.get('geometry'))
        if not include_boundary:
            # Shrink by 30 m (one assumed pixel) to drop boundary pixels.
            # NOTE(review): bounds() turns the shrunken polygon into its
            # bounding box - confirm this is intended.
            roi = roi.buffer(-30).bounds()
    except ValueError as e:
        st.warning(f"Skipping invalid polygon {polygon_name}: {e}")
        return None
    return roi, polygon_name, {}


def _aggregate_collection(collection, period_key, start_date_str, end_date_str):
    """Apply the aggregation function that belongs to *period_key*."""
    if period_key == 'daily':
        return aggregate_data_daily(collection)
    if period_key == 'weekly':
        return aggregate_data_weekly(collection)
    if period_key == 'monthly':
        return aggregate_data_monthly(collection, start_date_str, end_date_str)
    return aggregate_data_yearly(collection)


def _collect_period_rows(roi, location_name, location_fields, start_date_str, end_date_str,
                         dataset_id, selected_bands, reducer_choice, period_key, custom_formula):
    """Compute one result row per aggregated period image for a single region."""
    time_property, period_label, date_format = _PERIOD_SETTINGS[period_key]
    collection = (
        ee.ImageCollection(dataset_id)
        .filterDate(ee.Date(start_date_str), ee.Date(end_date_str))
        .filterBounds(roi)
    )
    collection = _aggregate_collection(collection, period_key, start_date_str, end_date_str)
    image_list = collection.toList(collection.size())

    range_start = pd.to_datetime(start_date_str)
    range_end = pd.to_datetime(end_date_str)
    processed_weeks = set()
    rows = []
    for i in range(image_list.size().getInfo()):
        image = ee.Image(image_list.get(i))
        timestamp = image.get(time_property)
        if date_format is None:
            # Weekly: skip weeks starting outside the requested range and
            # weeks that were already reported.
            date = ee.String(timestamp).getInfo()
            week_start = pd.to_datetime(date)
            if week_start < range_start or week_start > range_end or date in processed_weeks:
                continue
            processed_weeks.add(date)
        else:
            date = ee.Date(timestamp).format(date_format).getInfo()

        index_image = calculate_index_for_period(image, roi, selected_bands, custom_formula, reducer_choice)
        try:
            calculated_value = index_image.reduceRegion(
                reducer=get_reducer(reducer_choice),
                geometry=roi,
                scale=30
            ).get('custom_result').getInfo()
            if isinstance(calculated_value, (int, float)):
                rows.append({
                    'Location Name': location_name,
                    **location_fields,
                    period_label: date,
                    'Start Date': start_date_str,
                    'End Date': end_date_str,
                    'Calculated Value': calculated_value
                })
            else:
                st.warning(f"Skipping invalid value for {location_name} on {date}")
        except Exception as e:
            st.error(f"Error retrieving value for {location_name}: {e}")
    return rows


def process_aggregation(locations_df, start_date_str, end_date_str, dataset_id, selected_bands, reducer_choice, shape_type, aggregation_period, custom_formula="", kernel_size=None, include_boundary=None):
    """
    Compute the custom index for every uploaded location and period.

    For each row of *locations_df* a region is built (a point, a point kernel,
    or a polygon), the dataset is filtered to the date range and region,
    aggregated by *aggregation_period*, and the custom formula is evaluated on
    each aggregated image. Progress is reported through a Streamlit progress
    bar.

    Returns a list of record dicts. For the 'Daily' period the records are
    averaged per 'Location Name' and the value column is named
    'Aggregated Value'; for other periods there is one record per location
    and period with a 'Calculated Value' column. An empty list is returned
    when the formula is empty, the period is unknown, or nothing was computed.
    """
    aggregated_results = []
    if not custom_formula:
        st.error("Custom formula cannot be empty. Please provide a formula.")
        return aggregated_results

    period_key = aggregation_period.lower()
    if period_key not in _PERIOD_SETTINGS:
        st.error(f"Unsupported aggregation period: {aggregation_period}")
        return aggregated_results

    total_steps = len(locations_df)
    progress_bar = st.progress(0)
    progress_text = st.empty()

    shape_key = shape_type.lower()
    if shape_key == "point":
        build_target = lambda idx, row: _point_target(idx, row, kernel_size)
    elif shape_key == "polygon":
        build_target = lambda idx, row: _polygon_target(idx, row, include_boundary)
    else:
        build_target = None

    with st.spinner('Processing data...'):
        rows_to_process = locations_df.iterrows() if build_target is not None else ()
        # Progress is driven by row position, not by the index label, so it
        # stays within 0..1 for any index and also advances on skipped rows.
        for position, (idx, row) in enumerate(rows_to_process, start=1):
            target = build_target(idx, row)
            if target is not None:
                roi, location_name, location_fields = target
                aggregated_results.extend(_collect_period_rows(
                    roi, location_name, location_fields, start_date_str, end_date_str,
                    dataset_id, selected_bands, reducer_choice, period_key, custom_formula
                ))
            progress_percentage = min(position / total_steps, 1.0)
            progress_bar.progress(progress_percentage)
            progress_text.markdown(f"Processing: {int(progress_percentage * 100)}%")

    if not aggregated_results:
        return []

    result_df = pd.DataFrame(aggregated_results)
    if period_key != 'daily':
        return result_df.to_dict(orient='records')

    # Daily results are collapsed to a single mean value per location.
    agg_dict = {
        'Start Date': 'first',
        'End Date': 'first',
        'Calculated Value': 'mean'
    }
    if shape_key == 'point':
        agg_dict['Latitude'] = 'first'
        agg_dict['Longitude'] = 'first'
    aggregated_output = result_df.groupby('Location Name').agg(agg_dict).reset_index()
    aggregated_output.rename(columns={'Calculated Value': 'Aggregated Value'}, inplace=True)
    return aggregated_output.to_dict(orient='records')
# Button to trigger calculation
if st.button(f"Calculate({custom_formula})"):
    if file_upload is None:
        st.warning("Please upload a file.")
    elif shape_type.lower() in ["point", "polygon"]:
        results = process_aggregation(
            locations_df,
            start_date_str,
            end_date_str,
            dataset_id,
            selected_bands,
            reducer_choice,
            shape_type,
            aggregation_period,
            custom_formula,
            kernel_size=kernel_size,
            include_boundary=include_boundary
        )
        if results:
            result_df = pd.DataFrame(results)
            st.write(f"Processed Results Table ({aggregation_period}):")
            st.dataframe(result_df)
            # Build a file name without path separators or other characters
            # that are reserved in file names: the dataset id and the old
            # '%Y/%m/%d' dates both put "/" into the name.
            raw_filename = (
                f"{main_selection}_{dataset_id}_{start_date.strftime('%Y-%m-%d')}"
                f"_{end_date.strftime('%Y-%m-%d')}_{aggregation_period.lower()}.csv"
            )
            filename = re.sub(r'[\\/:*?"<>|]+', '_', raw_filename)
            st.download_button(
                label="Download results as CSV",
                data=result_df.to_csv(index=False).encode('utf-8'),
                file_name=filename,
                mime='text/csv'
            )
            st.success('Processing complete!')
        else:
            st.warning("No results were generated.")