# dvpi / app.py -- Hugging Face Space by KennethTM
# Revision 843b1b0 ("Update app.py", verified)
from fastapi import FastAPI
from pydantic import BaseModel, Field
import numpy as np
import onnxruntime as ort
from typing_extensions import Annotated
import gradio as gr
from cryptography.fernet import Fernet
import os
import pickle as pkl
# Model load
# The model and metadata files are Fernet-encrypted; the key is read from the
# ONNX_KEY environment variable.
key = os.getenv("ONNX_KEY")
if not key:
    # Fail fast with an actionable message instead of failing inside Fernet
    # with a less descriptive error when the variable is missing.
    raise RuntimeError(
        "Environment variable ONNX_KEY is not set; cannot decrypt the model files."
    )
cipher = Fernet(key)

VERSION = "0.0.3"
TITLE = f"DVPI beregnings API (version {VERSION})"
DESCRIPTION = "Beregn Dansk Vandløbs Plante Indeks (DVPI) fra dækningsgrad af plantearter. Beregningen er baseret på en model som efterligner DVPI beregningsmetoden og er dermed ikke eksakt, usikkerheden er i gennemsnit **±0.017 EQR-enheder** og **R<sup>2</sup>=0.98** når den sammenlignes med den originale. Kan der ikke beregnes en værdi, returneres EQR=0 og DVPI=0."
URL = "https://kennethtm-dvpi.hf.space"


def _read_decrypted(path: str) -> bytes:
    """Return the decrypted contents of the Fernet-encrypted file at *path*."""
    with open(path, "rb") as f:
        return cipher.decrypt(f.read())


# Load ONNX model and species mappings
ort_session = ort.InferenceSession(_read_decrypted("model_v3.bin"))

# Load metadata
# NOTE: pickle can execute arbitrary code while loading. This is acceptable
# here only because the payload must first pass Fernet decryption with
# ONNX_KEY -- keep that key secret and never point this at untrusted files.
metadata = pkl.loads(_read_decrypted("metadata_v3.bin"))

# Lookup tables used by the preprocessing, the endpoint and the UI.
latinname2stancode = metadata["latinname2stancode"]
valid_taxacodes = metadata["valid_taxacodes"]
normalizer_1 = metadata["normalizer_1"]
normalizer_2 = metadata["normalizer_2"]
taxacode2idx = metadata["taxacode2idx"]
# Preprocess species
def preprocess_species(species: dict[int, float]) -> dict[int, float]:
    """Normalize raw species codes into the taxa codes used as model input.

    The cover values are passed through two module-level lookup tables and
    then restricted to ``valid_taxacodes``:

    1. ``normalizer_1`` maps each incoming code to a new code. Codes missing
       from the table are dropped; covers landing on the same new code are
       summed.
    2. ``normalizer_2`` remaps codes again, summing collisions. A code mapped
       to ``None`` is dropped; a code absent from the table is kept as is.

    Args:
        species: Mapping of species code to cover value.

    Returns:
        Mapping of taxa code to cover, containing only codes that are in
        ``valid_taxacodes``.
    """
    # Apply filter 1
    intermediate_species = {}
    for sccode, value in species.items():
        if sccode not in normalizer_1:
            continue
        new_sccode = normalizer_1[sccode]
        intermediate_species[new_sccode] = intermediate_species.get(new_sccode, 0) + value

    # Apply filter 2
    # NOTE(review): the nesting of the original `else` is an assumption
    # (absent from normalizer_2 -> keep, mapped to None -> drop); confirm
    # against the intended DVPI rules.
    final_species = {}
    for sccode, value in intermediate_species.items():
        if sccode not in normalizer_2:
            # NOTE(review): this assigns rather than adds, so a cover already
            # accumulated under the same code by a remapped entry would be
            # replaced. Left unchanged; verify whether that can occur.
            final_species[sccode] = value
            continue
        new_sccode = normalizer_2[sccode]
        if new_sccode is not None:
            final_species[new_sccode] = final_species.get(new_sccode, 0) + value

    # filter valid taxacodes
    return {
        taxacode: value
        for taxacode, value in final_species.items()
        if taxacode in valid_taxacodes
    }
class SpeciesCover(BaseModel):
    # Request body accepted by the /dvpi endpoint (see predict).
    # Documented with '#' comments rather than a docstring so the generated
    # schema text is left untouched.

    # Maps an integer species code to its cover value; each value is
    # validated to lie in the inclusive range 0..100.
    species: dict[int, Annotated[float, Field(ge=0, le=100)]]

    # Example payload attached to the model's JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [{
                "species": {
                    6458: 25.0,
                    4158: 15.5,
                    7208: 10.0
                }
            }]
        }
    }
class EQRResult(BaseModel):
    # Response body returned by the /dvpi endpoint (see predict).

    # EQR value; predict() rounds it to 3 decimals and returns 0 when the
    # input contains nothing usable.
    EQR: float
    # DVPI class derived from EQR by eqr_to_dvpi (1-5); predict() returns 0
    # when the input contains nothing usable.
    DVPI: int
    # API version string; defaults to the module-level VERSION.
    version: str = VERSION
# Create FastAPI app; title and description come from the module constants.
app = FastAPI(title=TITLE, description=DESCRIPTION)
def eqr_to_dvpi(eqr: float) -> int:
    """Convert an EQR value into its DVPI class (1-5).

    Class boundaries are half-open: a value exactly on a boundary belongs to
    the higher class (``0.20`` -> 2, ``0.70`` -> 5). Values below the first
    boundary, including negative ones, give class 1.

    Args:
        eqr: The EQR value to classify.

    Returns:
        The DVPI class, an integer from 1 to 5.

    Raises:
        ValueError: If ``eqr`` is NaN. NaN fails every ``<`` comparison and
            would otherwise be reported as class 5.
    """
    if eqr != eqr:  # NaN is the only value that is not equal to itself
        raise ValueError("eqr must be a number, got NaN")

    # Exclusive upper bounds of classes 1..4; anything at or above the last
    # bound is class 5.
    upper_bounds = (0.20, 0.35, 0.50, 0.70)
    for dvpi_class, upper_bound in enumerate(upper_bounds, start=1):
        if eqr < upper_bound:
            return dvpi_class
    return len(upper_bounds) + 1
# FastAPI routes
@app.post("/dvpi")
def predict(cover_data: SpeciesCover) -> EQRResult:
    """Predict EQR and DVPI from species cover data"""
    # The docstring above is left exactly as it was because FastAPI shows it
    # in the generated API docs; implementation notes are plain comments.
    species_preproc = preprocess_species(cover_data.species)

    # One row, one column per entry in valid_taxacodes; taxacode2idx gives
    # the column of each taxa code.
    input_vector = np.zeros((1, len(valid_taxacodes)))
    for species, cover in species_preproc.items():
        input_vector[0, taxacode2idx[species]] = cover

    # Nothing usable survived preprocessing: return the EQR=0 / DVPI=0
    # "no value" result without running the model.
    if np.sum(input_vector) == 0:
        return EQRResult(EQR=0, DVPI=0)

    input_name = ort_session.get_inputs()[0].name
    ort_inputs = {input_name: input_vector.astype(np.float32)}
    # The session returns two outputs; only the second one is used.
    _, output_2 = ort_session.run(None, ort_inputs)

    # Clamp the raw model output to [0, 1] so that neither an overshoot nor
    # a negative prediction is reported as EQR.
    eqr = min(max(float(output_2[0][0]), 0.0), 1.0)
    dvpi = eqr_to_dvpi(eqr)
    return EQRResult(EQR=round(eqr, 3), DVPI=dvpi)
# Gradio app
def add_entry(species, cover, current_dict) -> tuple[dict, dict]:
    """Add or update one species entry in the UI's species list.

    Args:
        species: Selected species name, or ``None`` when nothing is selected.
        cover: Cover value entered by the user, or ``None`` when left empty.
        current_dict: The species -> cover mapping built so far.

    Returns:
        The updated mapping twice: once for the state component and once for
        the JSON display. When ``species`` or ``cover`` is ``None`` the
        mapping is returned unchanged.
    """
    # An incomplete entry would store a None key or value that the
    # prediction step cannot handle, so ignore it.
    if species is None or cover is None:
        return current_dict, current_dict

    # Build a new dict instead of mutating the caller's state object.
    updated = {**current_dict, species: cover}
    return updated, updated
def gradio_predict(cover_data: dict):
    """Run the prediction for the species list built in the UI.

    Args:
        cover_data: Mapping of Latin species name to cover value.

    Returns:
        The prediction result as a plain dict (``EQRResult.model_dump()``),
        or an empty dict when there is nothing to predict from.
    """
    if not cover_data:
        return {}

    # Translate names to species codes. Entries with an unknown name or a
    # missing cover are skipped rather than crashing the UI callback.
    cover_data_code = {
        latinname2stancode[species]: cover
        for species, cover in cover_data.items()
        if species in latinname2stancode and cover is not None
    }
    if not cover_data_code:
        return {}

    # Going through SpeciesCover applies the same validation as the API.
    data = SpeciesCover(species=cover_data_code)
    result = predict(data)
    return result.model_dump()
# Gradio UI: a calculator tab that builds a species list and runs the
# prediction, and a documentation tab with API usage examples.
with gr.Blocks() as io:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Tab(label="Beregner"):
        gr.Markdown("Beregning er baseret på samfund af plantearter og deres dækningsgrad. Når API'et bruges anvendes arternes [Stancode](https://dce.au.dk/overvaagning/stancode/stancodelister) (SC1064) - se 'Dokumentation' for eksempel på brug.")

        # Holds the species -> cover dict that the buttons below modify.
        current_dict = gr.State({})

        with gr.Row():
            # Iterating a dict yields its keys, so sorted() needs no list()/keys().
            species_choices = sorted(latinname2stancode)
            species_input = gr.Dropdown(choices=species_choices, label="Vælg art")
            cover_input = gr.Number(label="Dækningsgrad (%)", minimum=0, maximum=100)

        with gr.Row():
            add_btn = gr.Button("Tilføj")
            reset_btn = gr.Button("Nulstil")

        list_display = gr.JSON(label="Artsliste")
        calc_btn = gr.Button("Beregn")
        results = gr.JSON(label="Resultater")

        def reset_dict():
            # Clear the state, the species list display and the results.
            return {}, {}, {}

        add_btn.click(
            add_entry,
            inputs=[species_input, cover_input, current_dict],
            outputs=[current_dict, list_display],
            show_api=False,
        )
        reset_btn.click(
            reset_dict,
            inputs=[],
            outputs=[current_dict, list_display, results],
            show_api=False,
        )
        calc_btn.click(
            gradio_predict,
            inputs=[current_dict],
            outputs=results,
            show_api=False,
        )

        gr.Markdown("App og model af Kenneth Thorø Martinsen ([email protected]).")

    with gr.Tab(label="Dokumentation"):
        gr.Markdown("## Eksempel på brug af API")
        gr.Markdown(f"API dokumentation kan findes på [{URL}/docs]({URL}/docs)")

        gr.Markdown("### Python")
        gr.Code(f"""
import requests
import json
data = {{
"species": {{
6458: 25.0,
4158: 15.5,
7208: 10.0
}}
}}
response = requests.post("{URL}/dvpi", json=data)
print(response.json())
""")

        gr.Markdown("### R")
        # Numeric names are not valid bare argument names in R, so the
        # species codes are quoted with backticks.
        gr.Code(f"""
library(httr)
library(jsonlite)
data <- list(species = list(
`6458` = 25.0,
`4158` = 15.5,
`7208` = 10.0
))
response <- POST("{URL}/dvpi",
body = toJSON(data, auto_unbox = TRUE),
content_type("application/json"))
print(fromJSON(rawToChar(response$content)))
""")
# Mount Gradio app
# Attach the Gradio UI (`io`) to the FastAPI app at the root path and rebind
# `app` to the object returned by gr.mount_gradio_app.
app = gr.mount_gradio_app(app, io, path="/")