import streamlit as st
import json
import ee
import os
import pandas as pd
import geopandas as gpd
from datetime import datetime
import leafmap.foliumap as leafmap
import re
from shapely.geometry import base
from xml.etree import ElementTree as XET
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import matplotlib.pyplot as plt
import plotly.express as px
# Set up the page layout
st.set_page_config(layout="wide")

# Custom button styling
m = st.markdown(
    """
    <style>
    div.stButton > button:first-child {
        background-color: #006400;
        color: #ffffff;
    }
    </style>""",
    unsafe_allow_html=True,
)

# Logo and Title
st.write(
    """
    <div style="display: flex; justify-content: space-between; align-items: center;">
        <img src="https://huggingface.co/spaces/YashMK89/GEE_Calculator/resolve/main/ISRO_Logo.png" style="width: 20%; margin-right: auto;">
        <img src="https://huggingface.co/spaces/YashMK89/GEE_Calculator/resolve/main/SAC_Logo.png" style="width: 20%; margin-left: auto;">
    </div>
    """,
    unsafe_allow_html=True,
)
st.markdown(
    """
    <div style="display: flex; flex-direction: column; align-items: center;">
        <img src="https://huggingface.co/spaces/YashMK89/GEE_Calculator/resolve/main/SATRANG.png" style="width: 30%;">
        <h3 style="text-align: center; margin: 0;">( Spatial and Temporal Aggregation for Remote-sensing Analysis of GEE Data )</h3>
    </div>
    <hr>
    """,
    unsafe_allow_html=True,
)
# Authenticate and initialize Earth Engine
earthengine_credentials = os.environ.get("EE_Authentication")
if not earthengine_credentials:
    st.error("Earth Engine credentials not found. Please set the EE_Authentication environment variable.")
    st.stop()
os.makedirs(os.path.expanduser("~/.config/earthengine/"), exist_ok=True)
with open(os.path.expanduser("~/.config/earthengine/credentials"), "w") as f:
    f.write(earthengine_credentials)
ee.Initialize(project='ee-yashsacisro24')
# Helper function to get reducer
def get_reducer(reducer_name):
    reducers = {
        'mean': ee.Reducer.mean(),
        'sum': ee.Reducer.sum(),
        'median': ee.Reducer.median(),
        'min': ee.Reducer.min(),
        'max': ee.Reducer.max(),
        'count': ee.Reducer.count(),
    }
    return reducers.get(reducer_name.lower(), ee.Reducer.mean())
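# Illustrative usage: get_reducer('median') returns ee.Reducer.median();
# unrecognized names fall back to ee.Reducer.mean().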
# Function to convert geometry to Earth Engine format
def convert_to_ee_geometry(geometry):
    if isinstance(geometry, base.BaseGeometry):
        if geometry.is_valid:
            geojson = geometry.__geo_interface__
            return ee.Geometry(geojson)
        else:
            raise ValueError("Invalid geometry: The polygon geometry is not valid.")
    elif isinstance(geometry, str) and geometry.lower().endswith(".kml"):
        # Handle KML file paths before the generic string branch below,
        # which would otherwise swallow every string input.
        try:
            tree = XET.parse(geometry)
            kml_root = tree.getroot()
            kml_namespace = {'kml': 'http://www.opengis.net/kml/2.2'}
            coordinates = kml_root.findall(".//kml:coordinates", kml_namespace)
            if coordinates:
                coords_text = coordinates[0].text.strip()
                coords = coords_text.split()
                coords = [tuple(map(float, coord.split(','))) for coord in coords]
                geojson = {"type": "Polygon", "coordinates": [coords]}
                return ee.Geometry(geojson)
            else:
                raise ValueError("KML does not contain valid coordinates.")
        except Exception as e:
            raise ValueError(f"Error parsing KML: {e}")
    elif isinstance(geometry, (dict, str)):
        try:
            if isinstance(geometry, str):
                geometry = json.loads(geometry)
            if 'type' in geometry and 'coordinates' in geometry:
                return ee.Geometry(geometry)
            else:
                raise ValueError("GeoJSON format is invalid.")
        except Exception as e:
            raise ValueError(f"Error parsing GeoJSON: {e}")
    else:
        raise ValueError("Unsupported geometry input type. Supported types are Shapely, GeoJSON, and KML.")
# Function to calculate custom formula
def calculate_custom_formula(image, geometry, selected_bands, custom_formula, reducer_choice, scale=30):
    try:
        band_values = {}
        band_names = image.bandNames().getInfo()
        for band in selected_bands:
            if band not in band_names:
                raise ValueError(f"Band '{band}' not found in the dataset.")
            band_values[band] = image.select(band)
        reducer = get_reducer(reducer_choice)
        reduced_values = {}
        for band in selected_bands:
            value = band_values[band].reduceRegion(
                reducer=reducer,
                geometry=geometry,
                scale=scale
            ).get(band).getInfo()
            reduced_values[band] = float(value if value is not None else 0)
        # Substitute band names as whole tokens (longest first) so that e.g.
        # 'B8' cannot clobber the 'B8' prefix inside 'B8A'.
        formula = custom_formula
        for band in sorted(selected_bands, key=len, reverse=True):
            formula = re.sub(rf'\b{re.escape(band)}\b', str(reduced_values[band]), formula)
        result = eval(formula, {"__builtins__": {}}, {})
        if not isinstance(result, (int, float)):
            raise ValueError("Formula did not result in a numeric value.")
        return ee.Image.constant(result).rename('custom_result')
    except ZeroDivisionError:
        st.error("Error: Division by zero in the formula.")
        return ee.Image(0).rename('custom_result').set('error', 'Division by zero')
    except SyntaxError:
        st.error(f"Error: Invalid syntax in formula '{custom_formula}'.")
        return ee.Image(0).rename('custom_result').set('error', 'Invalid syntax')
    except ValueError as e:
        st.error(f"Error: {str(e)}")
        return ee.Image(0).rename('custom_result').set('error', str(e))
    except Exception as e:
        st.error(f"Unexpected error: {e}")
        return ee.Image(0).rename('custom_result').set('error', str(e))
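# Illustrative usage (band names assume a Sentinel-2 style dataset):
# calculate_custom_formula(img, roi, ['B8', 'B4'], '(B8 - B4) / (B8 + B4)', 'mean')
# reduces each band over roi, evaluates the NDVI-style expression on the reduced
# values, and returns a constant image named 'custom_result'.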
# Aggregation functions
def aggregate_data_custom(collection):
    collection = collection.map(lambda image: image.set('day', ee.Date(image.get('system:time_start')).format('YYYY-MM-dd')))
    grouped_by_day = collection.aggregate_array('day').distinct()
    def calculate_daily_mean(day):
        daily_collection = collection.filter(ee.Filter.eq('day', day))
        daily_mean = daily_collection.mean()
        return daily_mean.set('day', day)
    daily_images = ee.List(grouped_by_day.map(calculate_daily_mean))
    return ee.ImageCollection(daily_images)

def aggregate_data_daily(collection):
    """
    Aggregates data on a daily basis.
    """
    def set_day_start(image):
        date = ee.Date(image.get('system:time_start'))
        day_start = date.format('YYYY-MM-dd')
        return image.set('day_start', day_start)
    collection = collection.map(set_day_start)
    grouped_by_day = collection.aggregate_array('day_start').distinct()
    def calculate_daily_mean(day_start):
        daily_collection = collection.filter(ee.Filter.eq('day_start', day_start))
        daily_mean = daily_collection.mean()
        return daily_mean.set('day_start', day_start)
    daily_images = ee.List(grouped_by_day.map(calculate_daily_mean))
    return ee.ImageCollection(daily_images)

def aggregate_data_weekly(collection, start_date_str, end_date_str):
    """
    Aggregates data on a weekly basis, starting from the exact start date provided by the user.
    """
    start_date = ee.Date(start_date_str)
    end_date = ee.Date(end_date_str)
    # Calculate the number of weeks between the start and end dates
    days_diff = end_date.difference(start_date, 'day')
    num_weeks = days_diff.divide(7).ceil().getInfo()  # Total number of weeks
    weekly_images = []
    for week in range(num_weeks):
        week_start = start_date.advance(week * 7, 'day')  # Start of the week
        week_end = week_start.advance(7, 'day')  # End of the week
        weekly_collection = collection.filterDate(week_start, week_end)
        if weekly_collection.size().getInfo() > 0:
            weekly_mean = weekly_collection.mean()
            weekly_mean = weekly_mean.set('week_start', week_start.format('YYYY-MM-dd'))
            weekly_images.append(weekly_mean)
    return ee.ImageCollection.fromImages(weekly_images)

def aggregate_data_monthly(collection, start_date, end_date):
    collection = collection.filterDate(start_date, end_date)
    collection = collection.map(lambda image: image.set('month', ee.Date(image.get('system:time_start')).format('YYYY-MM')))
    grouped_by_month = collection.aggregate_array('month').distinct()
    def calculate_monthly_mean(month):
        monthly_collection = collection.filter(ee.Filter.eq('month', month))
        monthly_mean = monthly_collection.mean()
        return monthly_mean.set('month', month)
    monthly_images = ee.List(grouped_by_month.map(calculate_monthly_mean))
    return ee.ImageCollection(monthly_images)

def aggregate_data_yearly(collection):
    collection = collection.map(lambda image: image.set('year', ee.Date(image.get('system:time_start')).format('YYYY')))
    grouped_by_year = collection.aggregate_array('year').distinct()
    def calculate_yearly_mean(year):
        yearly_collection = collection.filter(ee.Filter.eq('year', year))
        yearly_mean = yearly_collection.mean()
        return yearly_mean.set('year', year)
    yearly_images = ee.List(grouped_by_year.map(calculate_yearly_mean))
    return ee.ImageCollection(yearly_images)
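# Illustrative usage: aggregate_data_monthly(coll, '2024-01-01', '2024-06-30')
# yields one mean image per calendar month, each tagged with a 'month' property;
# the daily, weekly, and yearly helpers behave analogously for their periods.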
def calculate_cloud_percentage(image, cloud_band='QA60'):
    """
    Calculate the percentage of cloud-covered pixels in an image using the QA60 bitmask.
    Assumes the presence of the QA60 cloud mask band.
    """
    # Decode the QA60 bitmask
    qa60 = image.select(cloud_band)
    opaque_clouds = qa60.bitwiseAnd(1 << 10)  # Bit 10: Opaque clouds
    cirrus_clouds = qa60.bitwiseAnd(1 << 11)  # Bit 11: Cirrus clouds
    # Combine both cloud types into a single cloud mask
    cloud_mask = opaque_clouds.Or(cirrus_clouds)
    # Count total pixels and cloudy pixels
    total_pixels = qa60.reduceRegion(
        reducer=ee.Reducer.count(),
        geometry=image.geometry(),
        scale=60,  # QA60 resolution is 60 meters
        maxPixels=1e13
    ).get(cloud_band)
    cloudy_pixels = cloud_mask.reduceRegion(
        reducer=ee.Reducer.sum(),
        geometry=image.geometry(),
        scale=60,  # QA60 resolution is 60 meters
        maxPixels=1e13
    ).get(cloud_band)
    # Calculate the cloud percentage with a server-side division-by-zero guard;
    # total_pixels is a server-side object, so a client-side '== 0' test would
    # never fire.
    total = ee.Number(total_pixels)
    cloudy = ee.Number(cloudy_pixels)
    return ee.Number(ee.Algorithms.If(total.eq(0), 0, cloudy.divide(total).multiply(100)))
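# Note: this helper assumes a Sentinel-2 style QA60 band; calling it on a
# dataset without QA60 will raise an Earth Engine error at request time.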
# Preprocessing function with cloud filtering
def preprocess_collection(collection, cloud_threshold):
    """
    Apply cloud filtering to the image collection using the QA60 bitmask.
    - Tile-based filtering: mask out tiles whose cloud coverage exceeds the selected threshold.
    - Pixel-based filtering: mask out individual cloudy pixels.
    """
    def filter_tile(image):
        # Calculate cloud percentage for the tile
        cloud_percentage = calculate_cloud_percentage(image, cloud_band='QA60')
        # Mask the whole tile unless its cloud percentage is below the threshold
        # (the image stays in the collection but contributes no pixels)
        return image.set('cloud_percentage', cloud_percentage).updateMask(cloud_percentage.lt(cloud_threshold))
    def mask_cloudy_pixels(image):
        # Decode the QA60 bitmask
        qa60 = image.select('QA60')
        opaque_clouds = qa60.bitwiseAnd(1 << 10)  # Bit 10: Opaque clouds
        cirrus_clouds = qa60.bitwiseAnd(1 << 11)  # Bit 11: Cirrus clouds
        # Combine both cloud types into a single cloud mask
        cloud_mask = opaque_clouds.Or(cirrus_clouds)
        # Mask out cloudy pixels
        clear_pixels = cloud_mask.Not()  # Invert the mask to keep clear pixels
        return image.updateMask(clear_pixels)
    # Step 1: Apply tile-based filtering
    filtered_collection = collection.map(filter_tile)
    # Step 2: Apply pixel-based filtering
    masked_collection = filtered_collection.map(mask_cloudy_pixels)
    return masked_collection
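# Illustrative usage: preprocess_collection(s2_collection, 20) masks out any
# tile whose QA60-derived cloud fraction is 20% or higher, then masks cloudy
# pixels in the remaining tiles ('s2_collection' stands for any Sentinel-2
# ImageCollection).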
# Worker function for processing a single geometry
def process_single_geometry(row, start_date_str, end_date_str, dataset_id, selected_bands, reducer_choice, shape_type, aggregation_period, custom_formula, original_lat_col, original_lon_col, kernel_size=None, include_boundary=None):
    if shape_type.lower() == "point":
        latitude = row.get('latitude')
        longitude = row.get('longitude')
        if pd.isna(latitude) or pd.isna(longitude):
            return None  # Skip invalid points
        location_name = row.get('name', f"Location_{row.name}")
        if kernel_size == "3x3 Kernel":
            buffer_size = 45  # 90m x 90m bounding box
            roi = ee.Geometry.Point([longitude, latitude]).buffer(buffer_size).bounds()
        elif kernel_size == "5x5 Kernel":
            buffer_size = 75  # 150m x 150m bounding box
            roi = ee.Geometry.Point([longitude, latitude]).buffer(buffer_size).bounds()
        else:  # Point
            roi = ee.Geometry.Point([longitude, latitude])
    elif shape_type.lower() == "polygon":
        polygon_geometry = row.get('geometry')
        location_name = row.get('name', f"Polygon_{row.name}")
        try:
            roi = convert_to_ee_geometry(polygon_geometry)
            if not include_boundary:
                roi = roi.buffer(-30).bounds()
        except ValueError:
            return None  # Skip invalid polygons
    else:
        return None  # Unknown shape type; skip
    # Filter and aggregate the image collection
    collection = ee.ImageCollection(dataset_id) \
        .filterDate(ee.Date(start_date_str), ee.Date(end_date_str)) \
        .filterBounds(roi)
    if aggregation_period.lower() == 'custom (start date to end date)':
        collection = aggregate_data_custom(collection)
    elif aggregation_period.lower() == 'daily':
        collection = aggregate_data_daily(collection)
    elif aggregation_period.lower() == 'weekly':
        collection = aggregate_data_weekly(collection, start_date_str, end_date_str)
    elif aggregation_period.lower() == 'monthly':
        collection = aggregate_data_monthly(collection, start_date_str, end_date_str)
    elif aggregation_period.lower() == 'yearly':
        collection = aggregate_data_yearly(collection)
    # Process each image in the collection
    image_list = collection.toList(collection.size())
    processed_weeks = set()
    aggregated_results = []
    for i in range(image_list.size().getInfo()):
        image = ee.Image(image_list.get(i))
        if aggregation_period.lower() == 'custom (start date to end date)':
            timestamp = image.get('day')
            period_label = 'Date'
            date = ee.Date(timestamp).format('YYYY-MM-dd').getInfo()
        elif aggregation_period.lower() == 'daily':
            timestamp = image.get('day_start')
            period_label = 'Date'
            date = ee.String(timestamp).getInfo()
        elif aggregation_period.lower() == 'weekly':
            timestamp = image.get('week_start')
            period_label = 'Week'
            date = ee.String(timestamp).getInfo()
            if (pd.to_datetime(date) < pd.to_datetime(start_date_str) or
                    pd.to_datetime(date) > pd.to_datetime(end_date_str) or
                    date in processed_weeks):
                continue
            processed_weeks.add(date)
        elif aggregation_period.lower() == 'monthly':
            timestamp = image.get('month')
            period_label = 'Month'
            date = ee.Date(timestamp).format('YYYY-MM').getInfo()
        elif aggregation_period.lower() == 'yearly':
            timestamp = image.get('year')
            period_label = 'Year'
            date = ee.Date(timestamp).format('YYYY').getInfo()
        index_image = calculate_custom_formula(image, roi, selected_bands, custom_formula, reducer_choice, scale=30)
        try:
            index_value = index_image.reduceRegion(
                reducer=get_reducer(reducer_choice),
                geometry=roi,
                scale=30
            ).get('custom_result')
            calculated_value = index_value.getInfo()
            if isinstance(calculated_value, (int, float)):
                result = {
                    'Location Name': location_name,
                    period_label: date,
                    'Start Date': start_date_str,
                    'End Date': end_date_str,
                    'Calculated Value': calculated_value
                }
                if shape_type.lower() == 'point':
                    result[original_lat_col] = latitude  # Use original column name
                    result[original_lon_col] = longitude  # Use original column name
                aggregated_results.append(result)
        except Exception as e:
            st.error(f"Error retrieving value for {location_name}: {e}")
    return aggregated_results
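# Each element of the returned list is a dict with 'Location Name', a period
# column ('Date', 'Week', 'Month', or 'Year'), 'Start Date', 'End Date',
# 'Calculated Value', and, for points, the original latitude/longitude columns.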
# Main processing function
def process_aggregation(locations_df, start_date_str, end_date_str, dataset_id, selected_bands, reducer_choice, shape_type, aggregation_period, original_lat_col, original_lon_col, custom_formula="", kernel_size=None, include_boundary=None, cloud_threshold=0):
    aggregated_results = []
    total_steps = len(locations_df)
    progress_bar = st.progress(0)
    progress_text = st.empty()
    start_time = time.time()  # Start timing the process
    # Preprocess the image collection with cloud filtering
    raw_collection = ee.ImageCollection(dataset_id) \
        .filterDate(ee.Date(start_date_str), ee.Date(end_date_str))
    # Print the size of the original collection
    st.write(f"Original Collection Size: {raw_collection.size().getInfo()}")
    # Apply cloud filtering if threshold > 0
    if cloud_threshold > 0:
        raw_collection = preprocess_collection(raw_collection, cloud_threshold)
        # Print the size of the preprocessed collection
        st.write(f"Preprocessed Collection Size: {raw_collection.size().getInfo()}")
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = []
        for idx, row in locations_df.iterrows():
            future = executor.submit(
                process_single_geometry,
                row,
                start_date_str,
                end_date_str,
                dataset_id,
                selected_bands,
                reducer_choice,
                shape_type,
                aggregation_period,
                custom_formula,
                original_lat_col,
                original_lon_col,
                kernel_size,
                include_boundary
            )
            futures.append(future)
        completed = 0
        for future in as_completed(futures):
            result = future.result()
            if result:
                aggregated_results.extend(result)
            completed += 1
            progress_percentage = completed / total_steps
            progress_bar.progress(progress_percentage)
            progress_text.markdown(f"Processing: {int(progress_percentage * 100)}%")
    # End timing the process
    end_time = time.time()
    processing_time = end_time - start_time  # Calculate total processing time
    if aggregated_results:
        result_df = pd.DataFrame(aggregated_results)
        if aggregation_period.lower() == 'custom (start date to end date)':
            agg_dict = {
                'Start Date': 'first',
                'End Date': 'first',
                'Calculated Value': 'mean'
            }
            if shape_type.lower() == 'point':
                agg_dict[original_lat_col] = 'first'
                agg_dict[original_lon_col] = 'first'
            aggregated_output = result_df.groupby('Location Name').agg(agg_dict).reset_index()
            aggregated_output['Date Range'] = aggregated_output['Start Date'] + " to " + aggregated_output['End Date']
            aggregated_output.rename(columns={'Calculated Value': 'Aggregated Value'}, inplace=True)
            return aggregated_output.to_dict(orient='records'), processing_time
        else:
            return result_df.to_dict(orient='records'), processing_time
    return [], processing_time
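# Illustrative call (the dataset ID is an example, not a recommendation):
# results, secs = process_aggregation(df, '2024-11-01', '2024-12-01',
#     'COPERNICUS/S2_SR_HARMONIZED', ['B8', 'B4'], 'mean', 'Point', 'Monthly',
#     'lat', 'lon', custom_formula='(B8 - B4) / (B8 + B4)')
# 'results' is a list of row dicts ready for pd.DataFrame.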
# Streamlit App Logic
st.markdown("<h5>Image Collection</h5>", unsafe_allow_html=True)
imagery_base = st.selectbox("Select Imagery Base", ["Sentinel", "Landsat", "MODIS", "Custom Input"], index=0)

# Initialize data as an empty dictionary
data = {}
dataset_files = {
    "Sentinel": "sentinel_datasets.json",
    "Landsat": "landsat_datasets.json",
    "MODIS": "modis_datasets.json",
}
if imagery_base in dataset_files:
    dataset_file = dataset_files[imagery_base]
    try:
        with open(dataset_file) as f:
            data = json.load(f)
    except FileNotFoundError:
        st.error(f"Dataset file '{dataset_file}' not found.")
        data = {}
elif imagery_base == "Custom Input":
    custom_dataset_id = st.text_input("Enter Custom Earth Engine Dataset ID (e.g., AHN/AHN4)", value="")
    if custom_dataset_id:
        try:
            # Accept either a bare asset ID or a pasted ee.ImageCollection('...') snippet
            if custom_dataset_id.startswith("ee.ImageCollection("):
                custom_dataset_id = custom_dataset_id.replace("ee.ImageCollection(", "").rstrip(")").strip("'\"")
            collection = ee.ImageCollection(custom_dataset_id)
            band_names = collection.first().bandNames().getInfo()
            data = {
                f"Custom Dataset: {custom_dataset_id}": {
                    "sub_options": {custom_dataset_id: f"Custom Dataset ({custom_dataset_id})"},
                    "bands": {custom_dataset_id: band_names}
                }
            }
            st.write(f"Fetched bands for {custom_dataset_id}: {', '.join(band_names)}")
        except Exception as e:
            st.error(f"Error fetching dataset: {str(e)}. Please check the dataset ID and ensure it's valid in Google Earth Engine.")
            data = {}
    else:
        st.warning("Please enter a custom dataset ID to proceed.")
        data = {}

if not data:
    st.error("No valid dataset available. Please check your inputs.")
    st.stop()
st.markdown("<hr><h5><b>{}</b></h5>".format(imagery_base), unsafe_allow_html=True) | |
main_selection = st.selectbox(f"Select {imagery_base} Dataset Category", list(data.keys())) | |
sub_selection = None | |
dataset_id = None | |
if main_selection: | |
sub_options = data[main_selection]["sub_options"] | |
sub_selection = st.selectbox(f"Select Specific {imagery_base} Dataset ID", list(sub_options.keys())) | |
if sub_selection: | |
st.write(f"You selected: {main_selection} -> {sub_options[sub_selection]}") | |
st.write(f"Dataset ID: {sub_selection}") | |
dataset_id = sub_selection | |
st.markdown("<hr><h5><b>Earth Engine Index Calculator</b></h5>", unsafe_allow_html=True) | |
if main_selection and sub_selection: | |
dataset_bands = data[main_selection]["bands"].get(sub_selection, []) | |
st.write(f"Available Bands for {sub_options[sub_selection]}: {', '.join(dataset_bands)}") | |
selected_bands = st.multiselect( | |
"Select 1 or 2 Bands for Calculation", | |
options=dataset_bands, | |
default=[dataset_bands[0]] if dataset_bands else [], | |
help=f"Select 1 or 2 bands from: {', '.join(dataset_bands)}" | |
) | |
if len(selected_bands) < 1: | |
st.warning("Please select at least one band.") | |
st.stop() | |
if selected_bands: | |
if len(selected_bands) == 1: | |
default_formula = f"{selected_bands[0]}" | |
example = f"'{selected_bands[0]} * 2' or '{selected_bands[0]} + 1'" | |
else: | |
default_formula = f"({selected_bands[0]} - {selected_bands[1]}) / ({selected_bands[0]} + {selected_bands[1]})" | |
example = f"'{selected_bands[0]} * {selected_bands[1]} / 2' or '({selected_bands[0]} - {selected_bands[1]}) / ({selected_bands[0]} + {selected_bands[1]})'" | |
        custom_formula = st.text_input(
            "Enter Custom Formula (e.g., (B8 - B4) / (B8 + B4) or B4 * B3 / 2)",
            value=default_formula,
            help=f"Use only these bands: {', '.join(selected_bands)}. Examples: {example}"
        )
        def validate_formula(formula, selected_bands):
            allowed_chars = set(" +-*/()0123456789.")
            terms = re.findall(r'[a-zA-Z][a-zA-Z0-9_]*', formula)
            invalid_terms = [term for term in terms if term not in selected_bands]
            if invalid_terms:
                return False, f"Invalid terms in formula: {', '.join(invalid_terms)}. Use only {', '.join(selected_bands)}."
            if not all(char in allowed_chars or char in ''.join(selected_bands) for char in formula):
                return False, "Formula contains invalid characters. Use only bands, numbers, and operators (+, -, *, /, ())"
            return True, ""
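        # Illustrative check: validate_formula('(B8 - B4) / (B8 + B4)', ['B8', 'B4'])
        # returns (True, ''), while an unselected band such as 'B12' would be
        # reported back as an invalid term.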
        is_valid, error_message = validate_formula(custom_formula, selected_bands)
        if not is_valid:
            st.error(error_message)
            st.stop()
        elif not custom_formula:
            st.warning("Please enter a custom formula to proceed.")
            st.stop()
        st.write(f"Custom Formula: {custom_formula}")

reducer_choice = st.selectbox(
    "Select Reducer (e.g., mean, sum, median, min, max, count)",
    ['mean', 'sum', 'median', 'min', 'max', 'count'],
    index=0
)
start_date = st.date_input("Start Date", value=pd.to_datetime('2024-11-01'))
end_date = st.date_input("End Date", value=pd.to_datetime('2024-12-01'))
start_date_str = start_date.strftime('%Y-%m-%d')
end_date_str = end_date.strftime('%Y-%m-%d')

st.markdown("<h5>Cloud Filtering</h5>", unsafe_allow_html=True)
cloud_threshold = st.slider(
    "Select Maximum Cloud Coverage Threshold (%)",
    min_value=0,
    max_value=50,
    value=20,
    step=5,
    help="Tiles with cloud coverage above this threshold are masked out entirely; individual cloudy pixels are also masked in the remaining tiles."
)
aggregation_period = st.selectbox(
    "Select Aggregation Period (e.g., Custom (Start Date to End Date), Daily, Weekly, Monthly, Yearly)",
    ["Custom (Start Date to End Date)", "Daily", "Weekly", "Monthly", "Yearly"],
    index=0
)
shape_type = st.selectbox("Do you want to process 'Point' or 'Polygon' data?", ["Point", "Polygon"])
kernel_size = None
include_boundary = None
if shape_type.lower() == "point":
    kernel_size = st.selectbox(
        "Select Calculation Area (e.g., Point, 3x3 Kernel, 5x5 Kernel)",
        ["Point", "3x3 Kernel", "5x5 Kernel"],
        index=0,
        help="Choose 'Point' for exact point calculation, or a kernel size for area averaging."
    )
elif shape_type.lower() == "polygon":
    include_boundary = st.checkbox(
        "Include Boundary Pixels",
        value=True,
        help="Check to include pixels on the polygon boundary; uncheck to exclude them."
    )
file_upload = st.file_uploader(f"Upload your {shape_type} data (CSV, GeoJSON, KML)", type=["csv", "geojson", "kml"])
locations_df = pd.DataFrame()
original_lat_col = None
original_lon_col = None
if file_upload is not None:
    if shape_type.lower() == "point":
        if file_upload.name.endswith('.csv'):
            # Read the CSV file
            locations_df = pd.read_csv(file_upload)
            # Show the first few rows to help the user identify columns
            st.write("Preview of your uploaded data (first 5 rows):")
            st.dataframe(locations_df.head())
            # Get all column names from the uploaded file
            all_columns = locations_df.columns.tolist()
            # Let the user select latitude and longitude columns from dropdowns
            col1, col2 = st.columns(2)
            with col1:
                original_lat_col = st.selectbox(
                    "Select Latitude Column",
                    options=all_columns,
                    index=all_columns.index('latitude') if 'latitude' in all_columns else 0,
                    help="Select the column containing latitude values"
                )
            with col2:
                original_lon_col = st.selectbox(
                    "Select Longitude Column",
                    options=all_columns,
                    index=all_columns.index('longitude') if 'longitude' in all_columns else 0,
                    help="Select the column containing longitude values"
                )
            # Validate that the selected columns contain numeric data
            if not pd.api.types.is_numeric_dtype(locations_df[original_lat_col]) or not pd.api.types.is_numeric_dtype(locations_df[original_lon_col]):
                st.error("Error: Selected Latitude and Longitude columns must contain numeric values")
                st.stop()
            # Rename the selected columns to standard names for processing
            locations_df = locations_df.rename(columns={
                original_lat_col: 'latitude',
                original_lon_col: 'longitude'
            })
        elif file_upload.name.endswith('.geojson'):
            locations_df = gpd.read_file(file_upload)
            if 'geometry' in locations_df.columns:
                locations_df['latitude'] = locations_df['geometry'].y
                locations_df['longitude'] = locations_df['geometry'].x
                original_lat_col = 'latitude'
                original_lon_col = 'longitude'
            else:
                st.error("GeoJSON file doesn't contain a geometry column")
                st.stop()
        elif file_upload.name.endswith('.kml'):
            kml_string = file_upload.read().decode('utf-8')
            try:
                root = XET.fromstring(kml_string)
                ns = {'kml': 'http://www.opengis.net/kml/2.2'}
                points = []
                for placemark in root.findall('.//kml:Placemark', ns):
                    name = placemark.findtext('kml:name', default=f"Point_{len(points)}", namespaces=ns)
                    coords_elem = placemark.find('.//kml:Point/kml:coordinates', ns)
                    if coords_elem is not None:
                        coords_text = coords_elem.text.strip()
                        coords = [c.strip() for c in coords_text.split(',')]
                        if len(coords) >= 2:
                            lon, lat = float(coords[0]), float(coords[1])
                            points.append({'name': name, 'geometry': f"POINT ({lon} {lat})"})
                if not points:
                    st.error("No valid Point data found in the KML file.")
                else:
                    locations_df = gpd.GeoDataFrame(points, geometry=gpd.GeoSeries.from_wkt([p['geometry'] for p in points]), crs="EPSG:4326")
                    locations_df['latitude'] = locations_df['geometry'].y
                    locations_df['longitude'] = locations_df['geometry'].x
                    original_lat_col = 'latitude'
                    original_lon_col = 'longitude'
            except Exception as e:
                st.error(f"Error parsing KML file: {str(e)}")
        # Display map for points if we have valid data
        if not locations_df.empty and 'latitude' in locations_df.columns and 'longitude' in locations_df.columns:
            m = leafmap.Map(center=[locations_df['latitude'].mean(), locations_df['longitude'].mean()], zoom=10)
            for _, row in locations_df.iterrows():
                latitude = row['latitude']
                longitude = row['longitude']
                if pd.isna(latitude) or pd.isna(longitude):
                    continue
                m.add_marker(location=[latitude, longitude], popup=row.get('name', 'No Name'))
            st.write("Map of Uploaded Points:")
            m.to_streamlit()
    elif shape_type.lower() == "polygon":
        if file_upload.name.endswith('.csv'):
            st.error("CSV upload is not supported for polygons. Please upload a GeoJSON or KML file.")
        elif file_upload.name.endswith('.geojson'):
            locations_df = gpd.read_file(file_upload)
            if 'geometry' not in locations_df.columns:
                st.error("GeoJSON file doesn't contain a geometry column")
                st.stop()
        elif file_upload.name.endswith('.kml'):
            kml_string = file_upload.read().decode('utf-8')
            try:
                root = XET.fromstring(kml_string)
                ns = {'kml': 'http://www.opengis.net/kml/2.2'}
                polygons = []
                for placemark in root.findall('.//kml:Placemark', ns):
                    name = placemark.findtext('kml:name', default=f"Polygon_{len(polygons)}", namespaces=ns)
                    coords_elem = placemark.find('.//kml:Polygon//kml:coordinates', ns)
                    if coords_elem is not None:
                        coords_text = ' '.join(coords_elem.text.split())
                        coord_pairs = [pair.split(',')[:2] for pair in coords_text.split() if pair]
                        if len(coord_pairs) >= 4:
                            coords_str = " ".join([f"{float(lon)} {float(lat)}" for lon, lat in coord_pairs])
                            polygons.append({'name': name, 'geometry': f"POLYGON (({coords_str}))"})
                if not polygons:
                    st.error("No valid Polygon data found in the KML file.")
                else:
                    locations_df = gpd.GeoDataFrame(polygons, geometry=gpd.GeoSeries.from_wkt([p['geometry'] for p in polygons]), crs="EPSG:4326")
            except Exception as e:
                st.error(f"Error parsing KML file: {str(e)}")
        # Display map for polygons if we have valid data
        if not locations_df.empty and 'geometry' in locations_df.columns:
            centroid_lat = locations_df.geometry.centroid.y.mean()
            centroid_lon = locations_df.geometry.centroid.x.mean()
            m = leafmap.Map(center=[centroid_lat, centroid_lon], zoom=10)
            for _, row in locations_df.iterrows():
                polygon = row['geometry']
                if polygon.is_valid:
                    gdf = gpd.GeoDataFrame([row], geometry=[polygon], crs=locations_df.crs)
                    m.add_gdf(gdf=gdf, layer_name=row.get('name', 'Unnamed Polygon'))
            st.write("Map of Uploaded Polygons:")
            m.to_streamlit()
if st.button(f"Calculate {custom_formula}"):
    if not locations_df.empty:
        with st.spinner("Processing Data..."):
            try:
                results, processing_time = process_aggregation(
                    locations_df,
                    start_date_str,
                    end_date_str,
                    dataset_id,
                    selected_bands,
                    reducer_choice,
                    shape_type,
                    aggregation_period,
                    original_lat_col,
                    original_lon_col,
                    custom_formula,
                    kernel_size,
                    include_boundary,
                    cloud_threshold=cloud_threshold
                )
                if results:
                    result_df = pd.DataFrame(results)
                    st.write(f"Processed Results Table ({aggregation_period}) for Formula: {custom_formula}")
                    st.dataframe(result_df)
                    filename = f"{main_selection}_{dataset_id}_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}_{aggregation_period.lower()}.csv"
                    st.download_button(
                        label="Download results as CSV",
                        data=result_df.to_csv(index=False).encode('utf-8'),
                        file_name=filename,
                        mime='text/csv'
                    )
                    st.success(f"Processing complete! Total processing time: {processing_time:.2f} seconds.")
                    # Graph Visualization Section
                    st.markdown("<h5>Graph Visualization</h5>", unsafe_allow_html=True)
                    # Dynamically identify the time column
                    if aggregation_period.lower() == 'custom (start date to end date)':
                        x_column = 'Date Range'
                    elif 'Date' in result_df.columns:
                        x_column = 'Date'
                    elif 'Week' in result_df.columns:
                        x_column = 'Week'
                    elif 'Month' in result_df.columns:
                        x_column = 'Month'
                    elif 'Year' in result_df.columns:
                        x_column = 'Year'
                    else:
                        st.warning("No valid time column found for plotting.")
                        st.stop()
                    # The custom period renames 'Calculated Value' to 'Aggregated Value',
                    # so pick whichever value column is present
                    y_column = 'Aggregated Value' if 'Aggregated Value' in result_df.columns else 'Calculated Value'
                    # Line Chart
                    st.subheader("Line Chart")
                    st.line_chart(result_df.set_index(x_column)[y_column])
                    # Bar Chart
                    st.subheader("Bar Chart")
                    st.bar_chart(result_df.set_index(x_column)[y_column])
                    # Advanced Plot (Plotly)
                    st.subheader("Advanced Interactive Plot (Plotly)")
                    fig = px.line(
                        result_df,
                        x=x_column,
                        y=y_column,
                        color='Location Name',
                        title=f"{custom_formula} Over Time"
                    )
                    st.plotly_chart(fig)
                else:
                    st.warning("No results were generated. Check your inputs or formula.")
                    st.info(f"Total processing time: {processing_time:.2f} seconds.")
            except Exception as e:
                st.error(f"An error occurred during processing: {str(e)}")
    else:
        st.warning("Please upload a valid file to proceed.")