File size: 3,411 Bytes
5b24075
 
 
 
 
 
 
63ad23b
5b24075
63ad23b
5b24075
63ad23b
 
5b24075
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13efa6d
63ad23b
13efa6d
63ad23b
37d5e3e
5b24075
13efa6d
 
5b24075
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import streamlit as st
import leafmap.foliumap as leafmap
from transformers import PretrainedConfig
from folium import Icon

from messis.messis import Messis
from inference import perform_inference
from inference import generate_presigned_url

from dotenv import load_dotenv

# Run python-dotenv's loader before anything below reads the environment.
# NOTE(review): presumably this supplies the credentials needed by
# generate_presigned_url -- confirm against the deployment setup.
load_dotenv()
# Render this page with Streamlit's "wide" layout.
st.set_page_config(layout="wide")

# Load the model
@st.cache_resource
def load_model():
    """Load the Messis model and its configuration at a pinned revision.

    Both objects are loaded from the ``crop-classification/messis``
    repository at revision ``47d9ca4``. The model is loaded with
    ``./hf_cache/`` as its cache directory.

    Returns:
        A ``(model, config)`` tuple.
    """
    repo_id = 'crop-classification/messis'
    pinned_revision = '47d9ca4'

    hub_config = PretrainedConfig.from_pretrained(repo_id, revision=pinned_revision)
    network = Messis.from_pretrained(
        repo_id,
        cache_dir='./hf_cache/',
        revision=pinned_revision,
    )
    return network, hub_config


# Module-level handles used by the page function below.
model, config = load_model()

def perform_inference_step():
    """Render the "Step 2" page: imagery map plus crop classification overlay.

    Reads the ``(lat, lon)`` pair stored under
    ``st.session_state["selected_location"]``. If it is missing, an error and
    a link back to the location page are shown and nothing else is rendered.

    Otherwise the sidebar offers a timestep slider, a band selector and a
    "Perform Crop Classification" button. A leafmap map is then drawn with
    the COG imagery layer, a marker at the selected location and, once
    inference has been run, the prediction polygons filled green or red
    according to their ``Correct`` property.

    The most recent predictions are kept in ``st.session_state`` together
    with the location they were computed for. ``st.button`` is only true on
    the rerun triggered by the click, so without this the overlay would
    disappear as soon as the user moved the slider or switched bands.
    """
    import copy

    st.title("Step 2: Perform Crop Classification")

    if "selected_location" not in st.session_state:
        st.error("No location selected. Please select a location first.")
        st.page_link("pages/1_Select_Location.py", label="Select Location", icon="📍")
        return

    lat, lon = st.session_state["selected_location"]
    location = (lat, lon)

    # Sidebar
    st.sidebar.header("Settings")

    # Timestep Slider
    timestep = st.sidebar.slider("Select Timestep", 1, 9, 5)

    # Band Dropdown
    band_options = {
        "RGB": [1, 2, 3],  # Adjust indices based on the actual bands in your GeoTIFF
        "NIR": [4],
        "SWIR1": [5],
        "SWIR2": [6],
    }
    # (min, max) pair passed to the COG layer's ``rescale`` for each option.
    vmin_vmax = {
        "RGB": (89, 1878),
        "NIR": (165, 5468),
        "SWIR1": (120, 3361),
        "SWIR2": (94, 2700),
    }
    selected_band = st.sidebar.selectbox(
        "Select Satellite Band to Display",
        options=list(band_options),
        index=0,
    )

    # Calculate the band indices based on the selected timestep.
    # NOTE(review): assumes the COG stacks 6 bands per timestep -- confirm
    # against how stacked_features_cog.tif is produced.
    bands_per_timestep = 6
    band_offset = (timestep - 1) * bands_per_timestep
    selected_bands = [band + band_offset for band in band_options[selected_band]]

    instructions = """
    Click the button "Perform Crop Classification".

    _Note:_ 
    - Messis will classify the crop types for the fields in your selected location.
    - Hover over the fields to see the predicted and true crop type.
    - The satellite images might take a few seconds to load.
    """
    st.sidebar.header("Instructions")
    st.sidebar.markdown(instructions)

    # Initialize the map
    m = leafmap.Map(center=(lat, lon), zoom=10, draw_control=False)

    # Perform inference. The result is stored in the session so it survives
    # reruns caused by other widgets; the button itself is only true once.
    result_key = "inference_result"
    just_ran = st.sidebar.button("Perform Crop Classification", type="primary")
    if just_ran:
        st.session_state[result_key] = {
            "location": location,
            "predictions": perform_inference(lon, lat, model, config, debug=True),
        }

    stored = st.session_state[result_key] if result_key in st.session_state else None
    # Only draw predictions that belong to the currently selected location.
    if stored is not None and stored["location"] == location:
        # Hand the map a copy so the stored result is never modified by the
        # mapping layer across reruns.
        m.add_data(
            copy.copy(stored["predictions"]),
            layer_name="Predictions",
            column="Correct",
            add_legend=False,
            style_function=lambda x: {
                "fillColor": "green" if x["properties"]["Correct"] else "red",
                "color": "black",
                "weight": 0,
                "fillOpacity": 0.25,
            },
        )
        if just_ran:
            st.success("Inference completed!")

    # Add COG
    presigned_url = generate_presigned_url('messis-demo', 'stacked_features_cog.tif')
    rescale_min, rescale_max = vmin_vmax[selected_band]
    m.add_cog_layer(
        url=presigned_url,
        name="Sentinel-2 Satellite Imagery",
        bands=selected_bands,
        rescale=f"{rescale_min},{rescale_max}",
        zoom_to_layer=True,
    )

    # Show the POI on the map
    poi_icon = Icon(color="green", prefix="fa", icon="crosshairs")
    m.add_marker(location=[lat, lon], popup="Selected Location", layer_name="POI", icon=poi_icon)

    # Display the map in the Streamlit app
    m.to_streamlit()

# Render the page when this file is executed as a script.
if __name__ == "__main__":
    perform_inference_step()