|
import streamlit as st |
|
import leafmap.foliumap as leafmap |
|
from transformers import PretrainedConfig |
|
from folium import Icon |
|
|
|
from messis.messis import Messis |
|
from inference import perform_inference |
|
|
|
# Streamlit requires set_page_config to run before any other Streamlit element
# command, so it sits at module level ahead of load_model() (whose cached call
# may render a spinner) and the page rendering below.
st.set_page_config(layout="wide")

# Path to a local stacked-features GeoTIFF.
# NOTE(review): not referenced in the visible code -- the map streams the
# stacked features as a COG from S3 instead; confirm whether this is still used.
GEOTIFF_PATH = "./data/stacked_features.tif"
|
|
|
|
|
@st.cache_resource
def load_model(
    repo_id="crop-classification/messis",
    revision="47d9ca4",
    cache_dir="./hf_cache/",
):
    """Load the Messis model and its config from the Hugging Face Hub.

    Decorated with ``st.cache_resource`` so loading happens once per distinct
    argument tuple and the result is reused across Streamlit reruns.

    Args:
        repo_id: Hub repository holding both the Messis config and weights.
        revision: Git revision pinned for BOTH loads, so config and weights
            always come from the same commit.
        cache_dir: Local download cache directory used for both the config
            and the weights.

    Returns:
        A ``(model, config)`` tuple.
    """
    # Pass the same repo, revision and cache dir to both loads so the config
    # and the weights are fetched into (and reused from) one local cache.
    config = PretrainedConfig.from_pretrained(repo_id, cache_dir=cache_dir, revision=revision)
    model = Messis.from_pretrained(repo_id, cache_dir=cache_dir, revision=revision)
    return model, config


# Loaded at import time; perform_inference_step() reads these module globals.
model, config = load_model()
|
|
|
def perform_inference_step():
    """Render the "Step 2: Perform Crop Classification" page.

    Reads the location chosen on the "Select Location" page from
    ``st.session_state["selected_location"]`` (as ``(lat, lon)``). Lets the
    user pick a timestep and band combination in the sidebar, and runs Messis
    inference when the sidebar button is clicked. Draws the satellite COG, the
    selected point and, once computed, the classified fields on a leafmap map.

    Predictions are stored in ``st.session_state`` together with the location
    they were computed for. ``st.button`` is only ``True`` on the single rerun
    caused by the click, so without this the prediction layer would vanish as
    soon as the timestep or band widgets trigger another rerun.
    """
    st.title("Step 2: Perform Crop Classification")

    if "selected_location" not in st.session_state:
        st.error("No location selected. Please select a location first.")
        st.page_link("pages/1_Select_Location.py", label="Select Location", icon="📍")
        return

    lat, lon = st.session_state["selected_location"]

    st.sidebar.header("Settings")

    timestep = st.sidebar.slider("Select Timestep", 1, 9, 5)

    # 1-based band indices within a single timestep of the stacked COG.
    band_options = {
        "RGB": [1, 2, 3],
        "NIR": [4],
        "SWIR1": [5],
        "SWIR2": [6],
    }
    # Display stretch (min, max) for each band combination.
    vmin_vmax = {
        "RGB": (89, 1878),
        "NIR": (165, 5468),
        "SWIR1": (120, 3361),
        "SWIR2": (94, 2700),
    }
    # The COG stacks the timesteps one after another, this many bands each.
    bands_per_timestep = 6

    selected_band = st.sidebar.selectbox(
        "Select Satellite Band to Display",
        options=list(band_options.keys()),
        index=0,
    )

    # Shift the per-timestep band indices to the chosen timestep's slice.
    band_offset = (timestep - 1) * bands_per_timestep
    selected_bands = [band + band_offset for band in band_options[selected_band]]
    vmin, vmax = vmin_vmax[selected_band]

    instructions = """
Click the button "Perform Crop Classification".

_Note:_
- Messis will classify the crop types for the fields in your selected location.
- Hover over the fields to see the predicted and true crop type.
- The satellite images might take a few seconds to load.
"""
    st.sidebar.header("Instructions")
    st.sidebar.markdown(instructions)

    m = leafmap.Map(center=(lat, lon), zoom=10, draw_control=False)

    # Session-state key for the last inference result. The stored value is
    # ((lat, lon), predictions), so a result computed for another location
    # is never drawn here.
    predictions_key = "inference_predictions"
    location_key = (lat, lon)

    if st.sidebar.button("Perform Crop Classification", type="primary"):
        predictions = perform_inference(lon, lat, model, config, debug=True)
        st.session_state[predictions_key] = (location_key, predictions)
        st.success("Inference completed!")

    cached = st.session_state.get(predictions_key)
    if cached is not None and cached[0] == location_key:
        m.add_data(
            cached[1],
            layer_name="Predictions",
            column="Correct",
            add_legend=False,
            # Green for correctly classified fields, red otherwise.
            style_function=lambda x: {
                "fillColor": "green" if x["properties"]["Correct"] else "red",
                "color": "black",
                "weight": 0,
                "fillOpacity": 0.25,
            },
        )

    m.add_cog_layer(
        url="https://messis-demo.s3.amazonaws.com/stacked_features_cog.tif",
        name="AWS COG",
        bands=selected_bands,
        rescale=f"{vmin},{vmax}",
        zoom_to_layer=True,
    )

    poi_icon = Icon(color="green", prefix="fa", icon="crosshairs")
    m.add_marker(location=[lat, lon], popup="Selected Location", layer_name="POI", icon=poi_icon)

    m.to_streamlit()
|
|
|
# Entry point: render the page when this script is executed directly.
# NOTE(review): presumably Streamlit executes page scripts with
# __name__ == "__main__", so this runs on every rerun -- confirm.
if __name__ == "__main__":
    perform_inference_step()
|
|