|
import streamlit as st |
|
import leafmap.foliumap as leafmap |
|
from transformers import PretrainedConfig |
|
from folium import Icon |
|
import os |
|
|
|
from messis.messis import Messis |
|
from inference import perform_inference |
|
|
|
from dotenv import load_dotenv |
|
|
|
# Load variables from a local .env file into os.environ (no error if the file
# is absent); this page later reads TILE_LAYER_URL from the environment.
load_dotenv()

# Keep this before any other Streamlit command: older Streamlit versions
# require set_page_config to be the first Streamlit call on a page.
st.set_page_config(layout="wide")
|
|
|
|
|
# Hugging Face Hub location of the Messis checkpoint. The revision is pinned so
# the page always serves the same weights/config snapshot.
MESSIS_REPO_ID = "crop-classification/messis"
MESSIS_REVISION = "47d9ca4"
MESSIS_CACHE_DIR = "./hf_cache/"


@st.cache_resource
def load_model(repo_id=MESSIS_REPO_ID, revision=MESSIS_REVISION, cache_dir=MESSIS_CACHE_DIR):
    """Load the Messis model and its config from the Hugging Face Hub.

    Wrapped in ``st.cache_resource`` so the download/instantiation is shared
    across Streamlit reruns and sessions instead of repeating on every rerun.

    Args:
        repo_id: Hub repository holding the checkpoint.
        revision: Git revision (commit/tag/branch) to load.
        cache_dir: Local directory used to cache downloaded files.

    Returns:
        A ``(model, config)`` tuple.
    """
    # Pass the same revision AND cache_dir to both loaders so the config and
    # the weights come from the same snapshot and the same local cache
    # (previously the config went to the default HF cache instead).
    config = PretrainedConfig.from_pretrained(repo_id, cache_dir=cache_dir, revision=revision)
    model = Messis.from_pretrained(repo_id, cache_dir=cache_dir, revision=revision)
    return model, config


model, config = load_model()
|
|
|
def _render_instructions():
    """Show the step-2 usage instructions in the sidebar."""
    instructions = """
Click the button "Perform Crop Classification".

_Note:_
- Messis will classify the crop types for the fields in your selected location.
- Hover over the fields to see the predicted and true crop type.
- The satellite images might take a few seconds to load.
"""
    st.sidebar.header("Instructions")
    st.sidebar.markdown(instructions)


def _prediction_style(feature):
    """Return the Leaflet style dict for one prediction polygon.

    The fill is green when the feature's ``Correct`` property is truthy and red
    otherwise. The outline weight is 0, so only the translucent fill is visible.
    """
    correct = feature["properties"]["Correct"]
    return {
        "fillColor": "green" if correct else "red",
        "color": "black",
        "weight": 0,
        "fillOpacity": 0.25,
    }


def _add_satellite_layer(m):
    """Add the Sentinel-2 tile layer whose URL comes from ``TILE_LAYER_URL``.

    If the variable is unset or empty, show a warning and skip the layer rather
    than handing ``url=None`` to ``add_tile_layer``.
    """
    tile_url = os.environ.get("TILE_LAYER_URL")
    if not tile_url:
        st.warning("TILE_LAYER_URL is not set; Sentinel-2 satellite imagery cannot be displayed.")
        return
    m.add_tile_layer(
        url=tile_url,
        name="Sentinel-2 Satellite Imagery",
        attribution="Copernicus Sentinel data 2019 / ESA",
    )


def perform_inference_step():
    """Render the "Step 2" page: classify crops around the selected location.

    Reads ``st.session_state["selected_location"]``, a ``(lat, lon)`` pair,
    presumably set by the location-selection page. If it is missing, show an
    error with a link back to that page and return early.

    Otherwise, build a map centred on the location. When the button is pressed,
    run Messis inference and overlay the predictions, colored by correctness.
    Then add the satellite tiles and a marker for the point of interest.
    """
    st.title("Step 2: Perform Crop Classification")

    if "selected_location" not in st.session_state:
        st.error("No location selected. Please select a location first.")
        st.page_link("pages/1_Select_Location.py", label="Select Location", icon="📍")
        return

    lat, lon = st.session_state["selected_location"]

    _render_instructions()

    m = leafmap.Map(center=(lat, lon), zoom=10, draw_control=False)

    if st.button("Perform Crop Classification", type="primary"):
        # Inference can take a while; give the user visible feedback.
        with st.spinner("Running crop classification..."):
            # NOTE(review): arguments are passed as (lon, lat), the reverse of the
            # session_state order; presumably this matches perform_inference's
            # signature. Verify before changing.
            predictions = perform_inference(lon, lat, model, config, debug=True)

        m.add_data(
            predictions,
            layer_name="Predictions",
            column="Correct",
            add_legend=False,
            style_function=_prediction_style,
        )
        st.success("Inference completed!")

    _add_satellite_layer(m)

    poi_icon = Icon(color="green", prefix="fa", icon="crosshairs")
    m.add_marker(location=[lat, lon], popup="Selected Location", layer_name="POI", icon=poi_icon)

    m.to_streamlit()
|
|
|
# Streamlit executes page scripts with __name__ set to "__main__", so this
# guard renders the page when Streamlit runs the script, but not on import.
if __name__ == "__main__":
    perform_inference_step()
|
|