# DVPI calculation service.
#
# Exposes a FastAPI endpoint (POST /dvpi) that runs an encrypted ONNX model on
# a species-cover vector, and mounts a Gradio front end on "/" that calls the
# same prediction code.

import os
import pickle as pkl

from cryptography.fernet import Fernet
from fastapi import FastAPI
import gradio as gr
import numpy as np
import onnxruntime as ort
from pydantic import BaseModel, Field
from typing_extensions import Annotated

# Model load
key = os.getenv("ONNX_KEY")
if key is None:
    # Fail with an actionable message instead of the TypeError Fernet(None)
    # would raise.
    raise RuntimeError(
        "Environment variable ONNX_KEY is not set; it must hold the Fernet "
        "key used to encrypt the model and metadata files."
    )
cipher = Fernet(key)

VERSION = "0.0.3"
TITLE = f"DVPI beregnings API (version {VERSION})"
DESCRIPTION = (
    "Beregn Dansk Vandløbs Plante Indeks (DVPI) fra dækningsgrad af plantearter. "
    "Beregningen er baseret på en model som efterligner DVPI beregningsmetoden og er dermed ikke eksakt, "
    "usikkerheden er i gennemsnit **±0.017 EQR-enheder** og **R2=0.98** når den sammenlignes med den originale. "
    "Kan der ikke beregnes en værdi, returneres EQR=0 og DVPI=0."
)
URL = "https://kennethtm-dvpi.hf.space"


def _read_decrypted(path: str) -> bytes:
    """Return the Fernet-decrypted contents of the file at *path*."""
    with open(path, "rb") as f:
        return cipher.decrypt(f.read())


# Load ONNX model and species mappings
ort_session = ort.InferenceSession(_read_decrypted("model_v3.bin"))

# Load metadata
# SECURITY: pickle.loads can execute arbitrary code on crafted input. The
# payload is authenticated by Fernet, so it is only as trustworthy as the
# secrecy of ONNX_KEY and the party that produced metadata_v3.bin.
metadata = pkl.loads(_read_decrypted("metadata_v3.bin"))
latinname2stancode = metadata["latinname2stancode"]
valid_taxacodes = metadata["valid_taxacodes"]
normalizer_1 = metadata["normalizer_1"]
normalizer_2 = metadata["normalizer_2"]
taxacode2idx = metadata["taxacode2idx"]


# Preprocess species
def preprocess_species(species: dict[int, float]) -> dict[int, float]:
    """Normalise a species-cover mapping to the taxa codes the model accepts.

    Args:
        species: Mapping from species code to cover value.

    Returns:
        Mapping from taxa code to summed cover, restricted to codes present
        in ``valid_taxacodes``.
    """
    # Apply filter 1: translate codes through normalizer_1, dropping codes it
    # does not know and summing covers that land on the same code.
    intermediate_species: dict = {}
    for sccode, value in species.items():
        if sccode in normalizer_1:
            new_sccode = normalizer_1[sccode]
            intermediate_species[new_sccode] = (
                intermediate_species.get(new_sccode, 0) + value
            )

    # Apply filter 2.
    # NOTE(review): the original file's indentation was lost; this assumes
    # the pass-through branch pairs with `if sccode in normalizer_2`, i.e. a
    # None mapping drops the taxon and an unlisted code is kept unchanged.
    # Confirm against the upstream source.
    final_species: dict = {}
    for sccode, value in intermediate_species.items():
        if sccode in normalizer_2:
            new_sccode = normalizer_2[sccode]
            if new_sccode is None:
                continue
        else:
            new_sccode = sccode
        # Accumulate in every case so the result does not depend on whether
        # a remapped code is visited before or after its pass-through twin.
        final_species[new_sccode] = final_species.get(new_sccode, 0) + value

    # filter valid taxacodes
    return {
        taxacode: value
        for taxacode, value in final_species.items()
        if taxacode in valid_taxacodes
    }


# Request body: species code -> cover value, each cover validated to 0..100.
# (Documented with comments rather than a docstring so the generated OpenAPI
# schema is unchanged.)
class SpeciesCover(BaseModel):
    species: dict[int, Annotated[float, Field(ge=0, le=100)]]

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "species": {
                        6458: 25.0,
                        4158: 15.5,
                        7208: 10.0,
                    }
                }
            ]
        }
    }


# Response body: model EQR, derived DVPI class, and the API version.
class EQRResult(BaseModel):
    EQR: float
    DVPI: int
    version: str = VERSION


# Create FastAPI app
app = FastAPI(title=TITLE, description=DESCRIPTION)


def eqr_to_dvpi(eqr: float) -> int:
    """Map an EQR value to its DVPI class (1-5).

    Class boundaries are 0.20, 0.35, 0.50 and 0.70; a value equal to a
    boundary falls in the higher class.
    """
    if eqr < 0.20:
        return 1
    if eqr < 0.35:
        return 2
    if eqr < 0.50:
        return 3
    if eqr < 0.70:
        return 4
    return 5


# FastAPI routes
@app.post("/dvpi")
def predict(cover_data: SpeciesCover) -> EQRResult:
    """Predict EQR and DVPI from species cover data"""
    # The docstring above is published as the endpoint description, so it is
    # kept verbatim; implementation notes live in comments.
    species_preproc = preprocess_species(cover_data.species)

    # One row with one column per valid taxa code; absent taxa stay at zero.
    input_vector = np.zeros((1, len(valid_taxacodes)))
    for species, cover in species_preproc.items():
        input_vector[0, taxacode2idx[species]] = cover

    # Nothing usable in the input: return the documented "no value" result.
    if np.sum(input_vector) == 0:
        return EQRResult(EQR=0, DVPI=0)

    input_name = ort_session.get_inputs()[0].name
    ort_inputs = {input_name: input_vector.astype(np.float32)}
    # The session returns two outputs; the second is read as the EQR.
    _, output_2 = ort_session.run(None, ort_inputs)

    # Cap the EQR at 1.
    eqr = min(float(output_2[0][0]), 1.0)
    return EQRResult(EQR=round(eqr, 3), DVPI=eqr_to_dvpi(eqr))


# Gradio app
def add_entry(species, cover, current_dict) -> tuple[dict, dict]:
    """Add or replace one species entry in the UI's species list.

    Args:
        species: Species name chosen in the dropdown (None if nothing chosen).
        cover: Cover value from the number field (None if left empty).
        current_dict: The species list held in the session state.

    Returns:
        The updated species list twice: once for the state, once for display.
    """
    # An incomplete entry would later fail in gradio_predict (unknown name or
    # invalid cover), so leave the list untouched.
    if species is None or cover is None:
        return current_dict, current_dict
    updated = {**current_dict, species: cover}
    return updated, updated


def gradio_predict(cover_data: dict) -> dict:
    """Run the prediction for the UI's name -> cover species list.

    Returns an empty dict for an empty list, otherwise the EQRResult fields
    as a plain dict.
    """
    if not cover_data:
        return {}
    cover_data_code = {
        latinname2stancode[species]: cover
        for species, cover in cover_data.items()
    }
    result = predict(SpeciesCover(species=cover_data_code))
    return result.model_dump()


def reset_dict():
    """Return empty values for the state, the list display and the results."""
    return {}, {}, {}


# Usage examples shown on the documentation tab.
PYTHON_EXAMPLE = f"""
import requests
import json

data = {{
    "species": {{
        6458: 25.0,
        4158: 15.5,
        7208: 10.0
    }}
}}

response = requests.post("{URL}/dvpi", json=data)
print(response.json())
"""

# List names are quoted: a bare numeric name (6458 = 25.0) is a syntax error
# in R.
R_EXAMPLE = f"""
library(httr)
library(jsonlite)

data <- list(species = list(
    "6458" = 25.0,
    "4158" = 15.5,
    "7208" = 10.0
))

response <- POST("{URL}/dvpi", body = toJSON(data, auto_unbox = TRUE), content_type("application/json"))
print(fromJSON(rawToChar(response$content)))
"""

with gr.Blocks() as io:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Tab(label="Beregner"):
        gr.Markdown("Beregning er baseret på samfund af plantearter og deres dækningsgrad. Når API'et bruges anvendes arternes [Stancode](https://dce.au.dk/overvaagning/stancode/stancodelister) (SC1064) - se 'Dokumentation' for eksempel på brug.")

        # Per-session species list (name -> cover).
        current_dict = gr.State({})

        with gr.Row():
            species_choices = sorted(latinname2stancode)
            species_input = gr.Dropdown(choices=species_choices, label="Vælg art")
            cover_input = gr.Number(label="Dækningsgrad (%)", minimum=0, maximum=100)

        with gr.Row():
            add_btn = gr.Button("Tilføj")
            reset_btn = gr.Button("Nulstil")

        list_display = gr.JSON(label="Artsliste")
        calc_btn = gr.Button("Beregn")
        results = gr.JSON(label="Resultater")

        add_btn.click(
            add_entry,
            inputs=[species_input, cover_input, current_dict],
            outputs=[current_dict, list_display],
            show_api=False,
        )

        reset_btn.click(
            reset_dict,
            inputs=[],
            outputs=[current_dict, list_display, results],
            show_api=False,
        )

        calc_btn.click(
            gradio_predict,
            inputs=[current_dict],
            outputs=results,
            show_api=False,
        )

        gr.Markdown("App og model af Kenneth Thorø Martinsen (kenneth2810@gmail.com).")

    with gr.Tab(label="Dokumentation"):
        gr.Markdown("## Eksempel på brug af API")
        gr.Markdown(f"API dokumentation kan findes på [{URL}/docs]({URL}/docs)")

        gr.Markdown("### Python")
        gr.Code(PYTHON_EXAMPLE)

        gr.Markdown("### R")
        gr.Code(R_EXAMPLE)

# Mount Gradio app
app = gr.mount_gradio_app(app, io, path="/")