|
from fastapi import FastAPI |
|
from pydantic import BaseModel, Field |
|
import numpy as np |
|
import onnxruntime as ort |
|
from typing_extensions import Annotated |
|
import gradio as gr |
|
from cryptography.fernet import Fernet |
|
import os |
|
import pickle as pkl |
|
|
|
|
|
# Fernet key used to decrypt model_v3.bin and metadata_v3.bin, read from the
# ONNX_KEY environment variable.
key = os.getenv("ONNX_KEY")
if not key:
    # Fail fast with an actionable message; passing None (or "") to Fernet
    # would otherwise fail with a much less descriptive error.
    raise RuntimeError(
        "Environment variable ONNX_KEY is not set; it must contain the Fernet "
        "key used to decrypt model_v3.bin and metadata_v3.bin."
    )
cipher = Fernet(key)
|
|
|
# API version; embedded in the title and returned with every result.
VERSION = "0.0.3"

TITLE = "DVPI beregnings API (version {})".format(VERSION)

# Markdown description shown in the Gradio UI and in the OpenAPI docs.
DESCRIPTION = (
    "Beregn Dansk Vandløbs Plante Indeks (DVPI) fra dækningsgrad af plantearter. "
    "Beregningen er baseret på en model som efterligner DVPI beregningsmetoden "
    "og er dermed ikke eksakt, "
    "usikkerheden er i gennemsnit **±0.017 EQR-enheder** og **R<sup>2</sup>=0.98** "
    "når den sammenlignes med den originale. "
    "Kan der ikke beregnes en værdi, returneres EQR=0 og DVPI=0."
)

# Public base URL, used to build the usage examples in the documentation tab.
URL = "https://kennethtm-dvpi.hf.space"
|
|
|
|
|
# Read the encrypted ONNX model, decrypt it in memory and build the inference
# session straight from the plaintext bytes (nothing decrypted touches disk).
with open("model_v3.bin", "rb") as f:
    encrypted = f.read()

decrypted = cipher.decrypt(encrypted)
ort_session = ort.InferenceSession(decrypted)
|
|
|
|
|
# Read, decrypt and unpickle the metadata bundled with the model.
# NOTE(review): pickle.loads can execute arbitrary code. This is acceptable only
# because the payload is a bundled file decrypted with the app's own key; never
# use it on user-supplied data.
with open("metadata_v3.bin", "rb") as f:
    encrypted = f.read()

decrypted = cipher.decrypt(encrypted)
metadata = pkl.loads(decrypted)

# Lookup tables used by preprocessing, prediction and the UI.
(
    latinname2stancode,
    valid_taxacodes,
    normalizer_1,
    normalizer_2,
    taxacode2idx,
) = (
    metadata[entry]
    for entry in (
        "latinname2stancode",
        "valid_taxacodes",
        "normalizer_1",
        "normalizer_2",
        "taxacode2idx",
    )
)
|
|
|
|
|
def preprocess_species(
    species: dict[int, float],
    *,
    primary_map: "dict | None" = None,
    secondary_map: "dict | None" = None,
    valid_codes=None,
) -> dict[int, float]:
    """Normalise input Stancodes to model taxon codes and aggregate their cover.

    Two remapping passes are applied. Whenever several codes end up on the same
    target code, their cover values are summed.

    1. ``normalizer_1`` maps each input code to an intermediate code; codes
       missing from it are dropped.
    2. ``normalizer_2`` remaps intermediate codes; a ``None`` target drops the
       code, and codes missing from ``normalizer_2`` are kept unchanged.

    Finally only codes contained in ``valid_taxacodes`` are returned.

    Args:
        species: Mapping of Stancode -> cover (%).
        primary_map: Optional replacement for ``normalizer_1``.
        secondary_map: Optional replacement for ``normalizer_2``.
        valid_codes: Optional replacement for ``valid_taxacodes``.

    Returns:
        Mapping of model taxon code -> aggregated cover.
    """
    # The module-level tables are only looked up when no override is given.
    map_1 = normalizer_1 if primary_map is None else primary_map
    map_2 = normalizer_2 if secondary_map is None else secondary_map
    allowed = valid_taxacodes if valid_codes is None else valid_codes

    intermediate = {}
    for code, value in species.items():
        if code in map_1:
            target = map_1[code]
            intermediate[target] = intermediate.get(target, 0) + value

    final = {}
    for code, value in intermediate.items():
        if code in map_2:
            target = map_2[code]
            if target is None:
                continue  # explicitly excluded code
        else:
            target = code
        # Accumulate rather than overwrite: a code kept unchanged may coincide
        # with a target that other codes have already been remapped onto.
        final[target] = final.get(target, 0) + value

    return {code: value for code, value in final.items() if code in allowed}
|
|
|
|
|
class SpeciesCover(BaseModel):
    # Request body for POST /dvpi: Stancode -> cover percentage.
    # Every cover value is validated to lie within [0, 100].
    species: dict[int, Annotated[float, Field(ge=0, le=100)]]

    # Example payload added to the model's JSON schema.
    model_config = dict(
        json_schema_extra=dict(
            examples=[
                {"species": {6458: 25.0, 4158: 15.5, 7208: 10.0}},
            ],
        ),
    )
|
|
|
class EQRResult(BaseModel):
    # Response body for POST /dvpi.
    EQR: float  # predicted EQR (predict rounds it to 3 decimals; 0 if not computable)
    DVPI: int  # class 1-5 derived from EQR, or 0 when no value could be computed
    version: str = VERSION  # API version that produced the result
|
|
|
|
|
# Pass VERSION so the OpenAPI info.version shown on /docs matches the API
# version, instead of FastAPI's default placeholder.
app = FastAPI(title=TITLE,
              description=DESCRIPTION,
              version=VERSION)
|
|
|
def eqr_to_dvpi(eqr: float) -> int:
    """Map an EQR value to its DVPI class (1-5).

    Class boundaries are half-open:
    - below 0.20 gives 1;
    - [0.20, 0.35) gives 2;
    - [0.35, 0.50) gives 3;
    - [0.50, 0.70) gives 4;
    - anything else (including values >= 0.70) gives 5.
    """
    upper_bounds = (0.20, 0.35, 0.50, 0.70)
    for dvpi_class, bound in enumerate(upper_bounds, start=1):
        if eqr < bound:
            return dvpi_class
    return len(upper_bounds) + 1
|
|
|
|
|
|
|
@app.post("/dvpi")
def predict(cover_data: SpeciesCover) -> EQRResult:
    """Predict EQR and DVPI from species cover data"""
    # (The docstring above is published by FastAPI as the endpoint description.)
    species_cover = preprocess_species(cover_data.species)

    # Build one feature row with one column per valid taxon; absent taxa stay 0.
    features = np.zeros((1, len(valid_taxacodes)))
    for taxon, cover in species_cover.items():
        features[0, taxacode2idx[taxon]] = cover

    # Nothing usable survived preprocessing: report "cannot compute".
    if np.sum(features) == 0:
        return EQRResult(EQR=0, DVPI=0)

    feeds = {ort_session.get_inputs()[0].name: features.astype(np.float32)}
    # The model yields two outputs; only the second one (the EQR) is used.
    _, eqr_output = ort_session.run(None, feeds)

    eqr = float(eqr_output[0][0])
    if eqr > 1:
        eqr = 1  # cap at the top of the EQR scale

    return EQRResult(EQR=round(eqr, 3), DVPI=eqr_to_dvpi(eqr))
|
|
|
|
|
def add_entry(species, cover, current_dict) -> tuple[dict, dict]:
    """Add or update one species in the UI's working species list.

    Args:
        species: Latin name selected in the dropdown (may be None/empty when
            nothing is selected).
        cover: Cover percentage from the number input (may be None when empty).
        current_dict: Current species -> cover mapping held in ``gr.State``.

    Returns:
        The updated mapping twice: once for the state, once for the JSON display.
    """
    # Work on a copy so the incoming state object is never mutated in place.
    updated = dict(current_dict or {})
    # Ignore incomplete input rather than storing a None key/value; such an
    # entry would later break the name -> Stancode lookup or the cover
    # validation when "Beregn" is pressed.
    if species and cover is not None:
        updated[species] = cover
    return updated, updated
|
|
|
def gradio_predict(cover_data: dict):
    """Run the DVPI prediction for the species list built in the Gradio UI.

    ``cover_data`` maps Latin species names to cover percentages. The names are
    translated to Stancodes, validated through ``SpeciesCover`` and passed to
    ``predict``. The result is returned as a plain dict for the JSON component.
    An empty list yields an empty dict.
    """
    if len(cover_data) > 0:
        by_stancode = {
            latinname2stancode[name]: pct for name, pct in cover_data.items()
        }
        return predict(SpeciesCover(species=by_stancode)).model_dump()
    return {}
|
|
|
with gr.Blocks() as io:

    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Tab(label="Beregner"):

        gr.Markdown("Beregning er baseret på samfund af plantearter og deres dækningsgrad. Når API'et bruges anvendes arternes [Stancode](https://dce.au.dk/overvaagning/stancode/stancodelister) (SC1064) - se 'Dokumentation' for eksempel på brug.")

        # Species -> cover mapping built up with the "Tilføj" button.
        current_dict = gr.State({})

        with gr.Row():
            species_choices = sorted(latinname2stancode.keys())
            species_input = gr.Dropdown(choices=species_choices, label="Vælg art")
            cover_input = gr.Number(label="Dækningsgrad (%)", minimum=0, maximum=100)

        with gr.Row():
            add_btn = gr.Button("Tilføj")
            reset_btn = gr.Button("Nulstil")

        list_display = gr.JSON(label="Artsliste")

        calc_btn = gr.Button("Beregn")
        results = gr.JSON(label="Resultater")

        def reset_dict():
            # Clear the state, the species list display and the results.
            return {}, {}, {}

        add_btn.click(
            add_entry,
            inputs=[species_input, cover_input, current_dict],
            outputs=[current_dict, list_display],
            show_api=False
        )

        reset_btn.click(
            reset_dict,
            inputs=[],
            outputs=[current_dict, list_display, results],
            show_api=False
        )

        calc_btn.click(
            gradio_predict,
            inputs=[current_dict],
            outputs=results,
            show_api=False
        )

        # NOTE(review): "[email protected]" looks like an e-mail obfuscation
        # placeholder that replaced the real contact address; restore it.
        gr.Markdown("App og model af Kenneth Thorø Martinsen ([email protected]).")

    with gr.Tab(label="Dokumentation"):

        gr.Markdown("## Eksempel på brug af API")
        gr.Markdown(f"API dokumentation kan findes på [{URL}/docs]({URL}/docs)")

        gr.Markdown("### Python")
        gr.Code(f"""
import requests
import json

data = {{
    "species": {{
        6458: 25.0,
        4158: 15.5,
        7208: 10.0
    }}
}}

response = requests.post("{URL}/dvpi", json=data)
print(response.json())
""")

        gr.Markdown("### R")
        # Numeric element names must be backtick-quoted in R:
        # `list(6458 = 25.0)` is a syntax error, list(`6458` = 25.0) is not.
        gr.Code(f"""
library(httr)
library(jsonlite)

data <- list(species = list(
    `6458` = 25.0,
    `4158` = 15.5,
    `7208` = 10.0
))

response <- POST("{URL}/dvpi",
                 body = toJSON(data, auto_unbox = TRUE),
                 content_type("application/json"))

print(fromJSON(rawToChar(response$content)))
""")
|
|
|
|
|
# Mount the Gradio UI (`io`) onto the FastAPI app at the root path and rebind
# `app` to the combined application.
app = gr.mount_gradio_app(
    app,
    io,
    path="/",
)
|
|