KennethTM commited on
Commit
a4f81c0
·
verified ·
1 Parent(s): 2df923c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -46
app.py CHANGED
@@ -1,53 +1,91 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel, Field
3
- from typing import Literal
4
- import json
5
  import numpy as np
6
  import onnxruntime as ort
7
  from typing_extensions import Annotated
8
  import gradio as gr
 
9
  from cryptography.fernet import Fernet
10
  import os
 
 
 
11
 
12
  # Model load
13
  key = os.getenv("ONNX_KEY")
14
  cipher = Fernet(key)
15
 
16
- VERSION = "0.0.1"
17
  TITLE = f"DVPI beregnings API (version {VERSION})"
18
- DESCRIPTION = "Beregn Dansk Vandløbs Plante Indeks (DVPI) fra dækningsgrad af plantearter. Beregningen er baseret på en model som efterligner DVPI beregningsmetoden og er dermed ikke eksakt, usikkerheden er i gennemsnit **±0.05 EQR-enheder**."
19
- URL = "https://kennethtm-dvpi.hf.space"
20
 
21
  # Load ONNX model and species mappings
22
- with open("model.bin", "rb") as f:
23
  encrypted = f.read()
24
  decrypted = cipher.decrypt(encrypted)
25
  ort_session = ort.InferenceSession(decrypted)
26
 
27
- with open("spec2idx.json", "r") as f:
28
- spec2idx = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # Define types
31
- valid_species = tuple(spec2idx.keys())
32
 
33
  class SpeciesCover(BaseModel):
34
- species: dict[Literal[valid_species], Annotated[float, Field(ge=0, le=100)]]
35
 
36
  model_config = {
37
  "json_schema_extra": {
38
  "examples": [{
39
  "species": {
40
- "Potamogeton alpinus": 25.0,
41
- "Berula erecta": 15.5,
42
- "Calamagrostis canescens": 10.0
43
  }
44
  }]
45
  }
46
  }
47
 
48
-
49
  class EQRResult(BaseModel):
50
- EQR: float # Round to 2 decimals
51
  DVPI: int
52
  version: str = VERSION
53
 
@@ -67,48 +105,47 @@ def eqr_to_dvpi(eqr: float) -> int:
67
  else:
68
  return 5
69
 
 
70
  # FastAPI routes
71
  @app.post("/dvpi")
72
  def predict(cover_data: SpeciesCover) -> EQRResult:
73
  """Predict EQR and DVPI from species cover data"""
74
- # Initialize input vector with zeros
75
- input_vector = np.zeros((1, len(spec2idx)))
76
 
77
- print(cover_data.species)
 
 
78
 
79
- # Fill values from input
80
- for species, cover in cover_data.species.items():
81
- idx = spec2idx[species]
82
  input_vector[0, idx] = cover
 
 
 
83
 
84
- # Get prediction
85
  input_name = ort_session.get_inputs()[0].name
86
  ort_inputs = {input_name: input_vector.astype(np.float32)}
87
- ort_output = ort_session.run(None, ort_inputs)
88
 
89
- eqr = float(ort_output[0][0])
 
90
  dvpi = eqr_to_dvpi(eqr)
91
 
92
- return EQRResult(EQR=round(eqr, 2), DVPI=dvpi)
93
-
94
- @app.get("/arter")
95
- def list_species() -> dict:
96
- """Return list of valid species names"""
97
- return {"species": list(spec2idx.keys())}
98
 
99
  # Gradio app
100
- def add_entry(species, cover, current_dict) -> tuple[SpeciesCover, str]:
101
 
102
  current_dict[species] = cover
103
-
104
  return current_dict, current_dict
105
 
106
  def gradio_predict(cover_data: dict):
107
 
108
  if len(cover_data) == 0:
109
  return {}
 
 
110
 
111
- data = SpeciesCover(species=cover_data)
112
  result = predict(data)
113
 
114
  return result.model_dump()
@@ -120,12 +157,13 @@ with gr.Blocks() as io:
120
 
121
  with gr.Tab(label = "Beregner"):
122
 
123
- gr.Markdown("Beregning er baseret på samfund af plantearter og deres dækningsgrad. Dækningsgraden angives i procent som summen af scoren for dækningsgraden (1-5) divideret med det samlede antal undersøgte kvadrater gange 5, og til sidste konverteret til procent. Eksempel: Potamogeton alpinus findes 3 felter med scorerne 2, 3 og 5 ud af 50 undersøgte kvadrater. Dækningsgraden for Potamogeton alpinus er derfor (2+3+5)/(50*5)*100 = 4%.")
124
 
125
  current_dict = gr.State({})
126
 
127
  with gr.Row():
128
- species_input = gr.Dropdown(choices=valid_species, label="Vælg art")
 
129
  cover_input = gr.Number(label="Dækningsgrad (%)", minimum=0, maximum=100)
130
 
131
  with gr.Row():
@@ -143,26 +181,28 @@ with gr.Blocks() as io:
143
  add_btn.click(
144
  add_entry,
145
  inputs=[species_input, cover_input, current_dict],
146
- outputs=[current_dict, list_display]
 
147
  )
148
 
149
  reset_btn.click(
150
  reset_dict,
151
  inputs=[],
152
- outputs=[current_dict, list_display, results]
 
153
  )
154
 
155
  calc_btn.click(
156
  gradio_predict,
157
  inputs=[current_dict],
158
- outputs=results
 
159
  )
160
 
161
  gr.Markdown("App og model af Kenneth Thorø Martinsen.")
162
 
163
  with gr.Tab(label="Dokumentation"):
164
 
165
- # Add markdown description with code to call the api in python
166
  gr.Markdown("## Eksempel på brug af API")
167
  gr.Markdown(f"API dokumentation kan findes på [{URL}/docs]({URL}/docs)")
168
  gr.Markdown("### Python")
@@ -172,9 +212,9 @@ import json
172
 
173
  data = {{
174
  "species": {{
175
- "Potamogeton alpinus": 25.0,
176
- "Berula erecta": 15.5,
177
- "Calamagrostis canescens": 10.0
178
  }}
179
  }}
180
 
@@ -188,9 +228,9 @@ library(httr)
188
  library(jsonlite)
189
 
190
  data <- list(species = list(
191
- "Potamogeton alpinus" = 25.0,
192
- "Berula erecta" = 15.5,
193
- "Calamagrostis canescens" = 10.0
194
  ))
195
 
196
  response <- POST("{URL}/dvpi",
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel, Field
 
 
3
  import numpy as np
4
  import onnxruntime as ort
5
  from typing_extensions import Annotated
6
  import gradio as gr
7
+ from dotenv import load_dotenv
8
  from cryptography.fernet import Fernet
9
  import os
10
+ import pickle as pkl
11
+
12
+ load_dotenv()
13
 
14
  # Model load
15
  key = os.getenv("ONNX_KEY")
16
  cipher = Fernet(key)
17
 
18
+ VERSION = "0.0.3"
19
  TITLE = f"DVPI beregnings API (version {VERSION})"
20
+ DESCRIPTION = "Beregn Dansk Vandløbs Plante Indeks (DVPI) fra dækningsgrad af plantearter. Beregningen er baseret på en model som efterligner DVPI beregningsmetoden og er dermed ikke eksakt, usikkerheden er i gennemsnit **±0.017 EQR-enheder** og **R<sup>2</sup>=0.98** når den sammenlignes med den originale. Kan der ikke beregnes en værdi, returneres EQR=0 og DVPI=0."
21
+ URL = "http://localhost:8000" #https://kennethtm-dvpi.hf.space
22
 
23
  # Load ONNX model and species mappings
24
+ with open("model_v3.bin", "rb") as f:
25
  encrypted = f.read()
26
  decrypted = cipher.decrypt(encrypted)
27
  ort_session = ort.InferenceSession(decrypted)
28
 
29
+ # Load metadata
30
+ with open("metadata_v3.bin", "rb") as f:
31
+ encrypted = f.read()
32
+ decrypted = cipher.decrypt(encrypted)
33
+ metadata = pkl.loads(decrypted)
34
+
35
+ latinname2stancode = metadata["latinname2stancode"]
36
+ valid_taxacodes = metadata["valid_taxacodes"]
37
+ normalizer_1 = metadata["normalizer_1"]
38
+ normalizer_2 = metadata["normalizer_2"]
39
+ taxacode2idx = metadata["taxacode2idx"]
40
+
41
+ # Preprocess species
42
+ def preprocess_species(species: dict[int: float]) -> dict[int: float]:
43
+ # Apply filter 1
44
+ intermediate_species = {}
45
+ for sccode, value in species.items():
46
+ if sccode in normalizer_1:
47
+ new_sccode = normalizer_1[sccode]
48
+ if new_sccode in intermediate_species:
49
+ intermediate_species[new_sccode] += value
50
+ else:
51
+ intermediate_species[new_sccode] = value
52
+
53
+ # Apply filter 2
54
+ final_species = {}
55
+ for sccode, value in intermediate_species.items():
56
+ if sccode in normalizer_2:
57
+ if normalizer_2[sccode] is not None:
58
+ new_sccode = normalizer_2[sccode]
59
+ if new_sccode in final_species:
60
+ final_species[new_sccode] += value
61
+ else:
62
+ final_species[new_sccode] = value
63
+ else:
64
+ final_species[sccode] = value
65
+
66
+ # filter valid taxacodes
67
+ final_species = {taxacode: value for taxacode, value in final_species.items() if taxacode in valid_taxacodes}
68
+
69
+ return final_species
70
 
 
 
71
 
72
  class SpeciesCover(BaseModel):
73
+ species: dict[int, Annotated[float, Field(ge=0, le=100)]]
74
 
75
  model_config = {
76
  "json_schema_extra": {
77
  "examples": [{
78
  "species": {
79
+ 6458: 25.0,
80
+ 4158: 15.5,
81
+ 7208: 10.0
82
  }
83
  }]
84
  }
85
  }
86
 
 
87
  class EQRResult(BaseModel):
88
+ EQR: float
89
  DVPI: int
90
  version: str = VERSION
91
 
 
105
  else:
106
  return 5
107
 
108
+
109
  # FastAPI routes
110
  @app.post("/dvpi")
111
  def predict(cover_data: SpeciesCover) -> EQRResult:
112
  """Predict EQR and DVPI from species cover data"""
 
 
113
 
114
+ species_preproc = preprocess_species(cover_data.species)
115
+
116
+ input_vector = np.zeros((1, len(valid_taxacodes)))
117
 
118
+ for species, cover in species_preproc.items():
119
+ idx = taxacode2idx[species]
 
120
  input_vector[0, idx] = cover
121
+
122
+ if np.sum(input_vector) == 0:
123
+ return EQRResult(EQR=0, DVPI=0)
124
 
 
125
  input_name = ort_session.get_inputs()[0].name
126
  ort_inputs = {input_name: input_vector.astype(np.float32)}
127
+ _, output_2 = ort_session.run(None, ort_inputs)
128
 
129
+ eqr = float(output_2[0][0])
130
+ eqr = 1 if eqr > 1 else eqr
131
  dvpi = eqr_to_dvpi(eqr)
132
 
133
+ return EQRResult(EQR=round(eqr, 3), DVPI=dvpi)
 
 
 
 
 
134
 
135
  # Gradio app
136
+ def add_entry(species, cover, current_dict) -> tuple[dict, str]:
137
 
138
  current_dict[species] = cover
 
139
  return current_dict, current_dict
140
 
141
  def gradio_predict(cover_data: dict):
142
 
143
  if len(cover_data) == 0:
144
  return {}
145
+
146
+ cover_data_code = {latinname2stancode[species]: cover for species, cover in cover_data.items()}
147
 
148
+ data = SpeciesCover(species=cover_data_code)
149
  result = predict(data)
150
 
151
  return result.model_dump()
 
157
 
158
  with gr.Tab(label = "Beregner"):
159
 
160
+ gr.Markdown("Beregning er baseret på samfund af plantearter og deres dækningsgrad. Når API'et bruges anvendes arternes [Stancode](https://dce.au.dk/overvaagning/stancode/stancodelister) (SC1064) - se 'Dokumentation' for eksempel brug.")
161
 
162
  current_dict = gr.State({})
163
 
164
  with gr.Row():
165
+ species_choices = sorted(list(latinname2stancode.keys()))
166
+ species_input = gr.Dropdown(choices=species_choices, label="Vælg art")
167
  cover_input = gr.Number(label="Dækningsgrad (%)", minimum=0, maximum=100)
168
 
169
  with gr.Row():
 
181
  add_btn.click(
182
  add_entry,
183
  inputs=[species_input, cover_input, current_dict],
184
+ outputs=[current_dict, list_display],
185
+ show_api=False
186
  )
187
 
188
  reset_btn.click(
189
  reset_dict,
190
  inputs=[],
191
+ outputs=[current_dict, list_display, results],
192
+ show_api=False
193
  )
194
 
195
  calc_btn.click(
196
  gradio_predict,
197
  inputs=[current_dict],
198
+ outputs=results,
199
+ show_api=False
200
  )
201
 
202
  gr.Markdown("App og model af Kenneth Thorø Martinsen.")
203
 
204
  with gr.Tab(label="Dokumentation"):
205
 
 
206
  gr.Markdown("## Eksempel på brug af API")
207
  gr.Markdown(f"API dokumentation kan findes på [{URL}/docs]({URL}/docs)")
208
  gr.Markdown("### Python")
 
212
 
213
  data = {{
214
  "species": {{
215
+ 6458: 25.0,
216
+ 4158: 15.5,
217
+ 7208: 10.0
218
  }}
219
  }}
220
 
 
228
  library(jsonlite)
229
 
230
  data <- list(species = list(
231
+ 6458 = 25.0,
232
+ 4158 = 15.5,
233
+ 7208 = 10.0
234
  ))
235
 
236
  response <- POST("{URL}/dvpi",