# messis-demo/pages/2_Perform_Crop_Classification.py
# Last commit: 672e5a7 by florinbarbisch
# Commit message: "Removed: sidebar and option to run locally"
import streamlit as st
import leafmap.foliumap as leafmap
from transformers import PretrainedConfig
from folium import Icon
import os
from messis.messis import Messis
from inference import perform_inference
from dotenv import load_dotenv
# Populate os.environ from a local .env file, if one is present (presumably
# this is where TILE_LAYER_URL, read further below, comes from in local runs).
load_dotenv()
# Use Streamlit's wide page layout so the map gets the full browser width.
# NOTE(review): Streamlit expects set_page_config before other st.* commands;
# keep this call ahead of any other Streamlit usage on this page.
st.set_page_config(layout="wide")
# Load the model
# Hugging Face Hub coordinates of the Messis checkpoint. The revision is pinned
# so the demo keeps loading the same weights even if the repository changes.
MODEL_REPO_ID = "crop-classification/messis"
MODEL_REVISION = "47d9ca4"
MODEL_CACHE_DIR = "./hf_cache/"


@st.cache_resource
def load_model(repo_id=MODEL_REPO_ID, revision=MODEL_REVISION, cache_dir=MODEL_CACHE_DIR):
    """Load the Messis model and its configuration from the Hugging Face Hub.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the already
    loaded objects instead of loading them again.

    Args:
        repo_id: Hub repository id of the checkpoint.
        revision: Git revision (commit hash, tag or branch) to load.
        cache_dir: Local download cache used for both the config and the
            model weights.

    Returns:
        A ``(model, config)`` tuple.
    """
    # Pass cache_dir to the config as well, so both artifacts share one cache
    # directory instead of the config going to the default HF cache.
    config = PretrainedConfig.from_pretrained(repo_id, revision=revision, cache_dir=cache_dir)
    model = Messis.from_pretrained(repo_id, revision=revision, cache_dir=cache_dir)
    return model, config


model, config = load_model()
# Sidebar usage notes for this step (rendered as Markdown).
_INSTRUCTIONS = """
Click the button "Perform Crop Classification".
_Note:_
- Messis will classify the crop types for the fields in your selected location.
- Hover over the fields to see the predicted and true crop type.
- The satellite images might take a few seconds to load.
"""


def _render_instructions():
    """Show the usage instructions for this step in the sidebar."""
    st.sidebar.header("Instructions")
    st.sidebar.markdown(_INSTRUCTIONS)


def _prediction_style(feature):
    """Style a prediction feature: green fill if correct, red fill otherwise."""
    is_correct = feature["properties"]["Correct"]
    return {
        "fillColor": "green" if is_correct else "red",
        "color": "black",
        "weight": 0,
        "fillOpacity": 0.25,
    }


def _add_satellite_layer(m):
    """Add the Sentinel-2 tile layer whose URL comes from TILE_LAYER_URL.

    If the variable is unset or empty, a warning is shown and no layer is added.
    """
    tile_url = os.environ.get("TILE_LAYER_URL")
    if not tile_url:
        # Without this guard, None would be handed straight to add_tile_layer.
        # Skipping keeps the predictions and the POI marker usable.
        st.warning("TILE_LAYER_URL is not set; satellite imagery is unavailable.")
        return
    m.add_tile_layer(
        url=tile_url,
        name="Sentinel-2 Satellite Imagery",
        attribution="Copernicus Sentinel data 2019 / ESA",
    )


def perform_inference_step():
    """Render step 2 of the demo: classify crops around the selected location.

    Reads ``st.session_state["selected_location"]`` as a ``(lat, lon)`` pair
    (presumably stored by the "Select Location" page linked below). If it is
    missing, an error and a link back to that page are shown and nothing else
    is rendered. Otherwise a map centred on the location is drawn, and clicking
    the button runs Messis and overlays per-field predictions on the map.
    """
    st.title("Step 2: Perform Crop Classification")

    if "selected_location" not in st.session_state:
        st.error("No location selected. Please select a location first.")
        st.page_link("pages/1_Select_Location.py", label="Select Location", icon="📍")
        return
    lat, lon = st.session_state["selected_location"]

    _render_instructions()

    m = leafmap.Map(center=(lat, lon), zoom=10, draw_control=False)

    if st.button("Perform Crop Classification", type="primary"):
        # Inference can take a while; give the user visible progress feedback.
        with st.spinner("Running crop classification..."):
            # NOTE(review): lon is passed before lat, as in the original call;
            # confirm this matches perform_inference's signature.
            predictions = perform_inference(lon, lat, model, config, debug=True)
        m.add_data(
            predictions,
            layer_name="Predictions",
            column="Correct",
            add_legend=False,
            style_function=_prediction_style,
        )
        st.success("Inference completed!")

    _add_satellite_layer(m)

    # Mark the selected point of interest.
    poi_icon = Icon(color="green", prefix="fa", icon="crosshairs")
    m.add_marker(location=[lat, lon], popup="Selected Location", layer_name="POI", icon=poi_icon)

    m.to_streamlit()
# Render the page when this script runs as the main module.
# NOTE(review): Streamlit is expected to execute page scripts as __main__;
# confirm for the deployed Streamlit version.
if __name__ == "__main__":
    perform_inference_step()