deepflash2 / app.py
matjesg's picture
Update app.py
807aa27
raw
history blame
1.23 kB
import numpy as np
import gradio as gr
import onnxruntime as ort
from matplotlib import pyplot as plt
from huggingface_hub import hf_hub_download
# Download the exported ONNX ensemble from the Hugging Face Hub. The returned
# value is a local file path, which is handed to the session factory below.
model = hf_hub_download(
    repo_id="matjesg/cFOS_in_HC",
    filename="ensemble.onnx",
)
def create_model_for_provider(model_path, provider="CPUExecutionProvider"):
    """Build an ONNX Runtime inference session pinned to a single provider.

    Args:
        model_path: Path (str or path-like) of the ``.onnx`` model file; it is
            converted with ``str()`` before being handed to ONNX Runtime.
        provider: Name of the only execution provider the session may use.

    Returns:
        An ``ort.InferenceSession`` configured with one intra-op thread, all
        graph optimizations enabled, and provider fallback disabled.
    """
    session_options = ort.SessionOptions()
    # Limit intra-op parallelism to a single thread.
    session_options.intra_op_num_threads = 1
    session_options.graph_optimization_level = (
        ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    )

    session = ort.InferenceSession(
        str(model_path),
        sess_options=session_options,
        providers=[provider],
    )
    # Turn off ONNX Runtime's run() fallback mechanism so the session sticks
    # to the requested provider.
    session.disable_fallback()
    return session
# Build one session at import time; inference() reuses it for every request.
ort_session = create_model_for_provider(model_path=model)
def inference(img, session=None):
    """Segment one image with the ONNX model and return a displayable result.

    Args:
        img: Image array with values in 0-255, either (H, W) or (H, W, C).
            Only the first channel is fed to the model. ``None`` (no image
            submitted) short-circuits and yields ``None``.
        session: Optional ONNX Runtime session (any object exposing
            ``get_inputs()`` and ``run()``). Defaults to the module-level
            ``ort_session``.

    Returns:
        The model's first output multiplied by 255, clipped to [0, 255] and
        cast to ``uint8``; or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None
    if session is None:
        session = ort_session

    img = np.asarray(img)
    if img.ndim == 2:
        # A grayscale image has no channel axis. Add one so the slice below
        # selects the channel instead of the first pixel column.
        img = img[..., np.newaxis]

    # Keep only the first channel and rescale 0-255 to 0-1 as float32.
    model_input = (img[..., :1] / 255).astype(np.float32)
    feeds = {session.get_inputs()[0].name: model_input}
    outputs = session.run(None, feeds)

    # Scale back to 0-255 and hand image encoders a well-defined 8-bit array
    # instead of raw float32 values far above 1.
    return np.clip(outputs[0] * 255, 0, 255).astype(np.uint8)
# Gradio UI: one uploaded image in, the segmentation image returned by
# inference() out.
# NOTE(review): gr.inputs / gr.outputs is Gradio's legacy component namespace
# (deprecated in Gradio 3.x, removed in 4.x). Keep the pinned Gradio version
# in sync with this, or migrate to gr.Image.
title = "deepflash2"
description = "deepflash2 is a deep-learning pipeline for segmentation of ambiguous microscopic images."
# Presumably an example image shipped next to app.py; verify it exists.
examples = [["1599.tif"]]

demo = gr.Interface(
    fn=inference,
    inputs=gr.inputs.Image(type="numpy"),
    outputs=gr.outputs.Image(),
    title=title,
    description=description,
    examples=examples,
)
demo.launch(share=True)