matjesg commited on
Commit
5a6b006
·
1 Parent(s): d49d239

Create new file

Browse files
Files changed (1) hide show
  1. app_onnx.py +43 -0
app_onnx.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import onnxruntime as ort
4
+ from matplotlib import pyplot as plt
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ def create_model_for_provider(model_path, provider="CPUExecutionProvider"):
8
+ options = ort.SessionOptions()
9
+ options.intra_op_num_threads = 1
10
+ options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
11
+ session = ort.InferenceSession(str(model_path), options, providers=[provider])
12
+ session.disable_fallback()
13
+ return session
14
+
15
+ def inference(repo_id, model_name, img):
16
+ model = hf_hub_download(repo_id=repo_id, filename=model_name)
17
+ ort_session = create_model_for_provider(model)
18
+ n_channels = ort_session.get_inputs()[0].shape[-1]
19
+
20
+ img = img[...,:n_channels]/255
21
+ ort_inputs = {ort_session.get_inputs()[0].name: img.astype(np.float32)}
22
+
23
+ ort_outs = ort_session.run(None, ort_inputs)
24
+
25
+ return ort_outs[0]*255, ort_outs[2]/0.25
26
+
27
+ title="deepflash2"
28
+ description='deepflash2 is a deep-learning pipeline for the segmentation of ambiguous microscopic images.\n deepflash2 uses deep model ensembles to achieve more accurate and reliable results. Thus, inference time will be more than a minute in this space.'
29
+ examples=[['matjesg/deepflash2_demo', 'cFOS_ensemble.onnx', 'cFOS_example.png'],
30
+ ['matjesg/deepflash2_demo', 'YFP_ensemble.onnx', 'YFP_example.png']
31
+ ]
32
+
33
+ gr.Interface(inference,
34
+ [gr.inputs.Textbox(placeholder='e.g., matjesg/cFOS_in_HC', label='repo_id'),
35
+ gr.inputs.Textbox(placeholder='e.g., ensemble.onnx', label='model_name'),
36
+ gr.inputs.Image(type='numpy', label='Input image')
37
+ ],
38
+ [gr.outputs.Image(label='Segmentation Mask'),
39
+ gr.outputs.Image(label='Uncertainty Map')],
40
+ title=title,
41
+ description=description,
42
+ examples=examples,
43
+ ).launch()