Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import os | |
import numpy as np | |
import SimpleITK as sitk | |
from scipy.ndimage import zoom | |
from resnet_gn import resnet50 | |
import pickle | |
def load_from_pkl(load_path):
    """Load and return the object pickled in *load_path*.

    The file is opened in a ``with`` block so the handle is closed even if
    unpickling raises.

    Security note: ``pickle.load`` can execute arbitrary code while loading;
    only call this on trusted files shipped with the app.
    """
    with open(load_path, 'rb') as data_input:
        return pickle.load(data_input)
# Currently loaded (normalized) volume and its case id. Both are set by
# get_Image_reslice and read by the slice / inference / clinical-info callbacks.
Image_3D = None
Current_name = None

# Clinical text per case, indexed as ALL_message[case_name]['past' | 'now'].
ALL_message = load_from_pkl(r'./label0601.pkl')

MODEL_PATH = r'./model_epoch62.pth.tar'
Model_Paht = MODEL_PATH  # legacy (misspelled) name kept for backward compatibility
# Load on the CPU; inference() moves the model to the compute device.
checkpoint = torch.load(MODEL_PATH, map_location='cpu')

# Single-output 3D ResNet-50 classifier; weights are stored under 'model_dict'.
classnet = resnet50(
    num_classes=1,
    sample_size=128,
    sample_duration=8)
classnet.load_state_dict(checkpoint['model_dict'])
def resize3D(img, aimsize, order=3):
    """Resample a 3D array to a target shape with ``scipy.ndimage.zoom``.

    :param img: 3D array.
    :param aimsize: list/tuple with one or three elements, e.g. ``[256]`` or
        ``[256, 56, 56]``. A single element is used for all three axes. Any
        element may be ``None`` to keep that axis at its current size.
    :param order: spline interpolation order passed to ``zoom`` (default 3).
    :return: the resampled array; its shape is ``aimsize`` (with ``None``
        axes unchanged), as rounded by ``zoom``.
    :raises ValueError: if ``aimsize`` does not have one or three elements.
    """
    _shape = img.shape
    if len(aimsize) == 1:
        aimsize = [aimsize[0]] * 3
    if len(aimsize) != 3:
        raise ValueError(f"aimsize must have 1 or 3 elements, got {len(aimsize)}")
    # Per-axis zoom factor; None means "leave this axis unchanged" (factor 1).
    # This handles any number of None axes, not just one.
    factors = tuple(
        1 if target is None else target / size
        for target, size in zip(aimsize, _shape)
    )
    return zoom(img, factors, order=order)  # resample for cube_size
def inference(model=None, data=None):
    """Score the loaded volume as acute (急性期) vs. subacute (亚急性期).

    Both arguments default to the module-level ``classnet`` and ``Image_3D``
    so the Gradio button can call this with no inputs.

    :param model: torch module taking a (1, 1, D, H, W) float tensor; the
        first value of its flattened output is used as the subacute score.
    :param data: 3D numpy array (D, H, W); defaults to the uploaded volume.
    :return: ``{'急性期': round(1 - p, 2), '亚急性期': round(p, 2)}`` where ``p``
        is that first output value, or the prompt ``'请先加载数据'`` when no
        volume has been loaded yet.
    """
    if model is None:
        model = classnet
    if data is None:
        data = Image_3D
    if data is None:
        # Nothing uploaded yet; same prompt as get_medical_message.
        return '请先加载数据'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The weights are loaded on the CPU; keep model and input on the same
    # device, otherwise forward() fails with a device mismatch on CUDA hosts.
    model.to(device)
    model.eval()
    with torch.no_grad():
        # (D, H, W) -> (N=1, C=1, D, H, W)
        patch_data = torch.from_numpy(data)[None, None].to(device).float()
        pre_probs = model(patch_data)
        # NOTE(review): the raw output is treated as a probability (a sigmoid
        # was left as a TODO); confirm the network's last layer is a sigmoid,
        # otherwise p may fall outside [0, 1].
        p_raw = pre_probs.view(-1)[0].item()  # .item() also works for CUDA tensors
    p = float(np.round(p_raw, decimals=2))
    n = float(np.round(1 - p, decimals=2))
    return {'急性期': n, '亚急性期': p}
# | |
# | |
# def image_classifier(inp): | |
# #return {'cat': 0.3, 'dog': 0.7} | |
# return inp | |
# | |
# def image_read(inp): | |
# image = sitk.GetArrayFromImage(sitk.ReadImage(inp)) | |
# ss = np.sum(image) | |
# return str(ss) | |
# | |
# | |
# def upload_file(files): | |
# file_paths = [file.name for file in files] | |
# return file_paths | |
# | |
# with gr.Blocks() as demo: | |
# file_output = gr.File() | |
# upload_button = gr.UploadButton("Click to Upload a File", file_types=["image", "video"], file_count="multiple") | |
# upload_button.upload(upload_file, upload_button, gr.Code('')) | |
# demo.launch() | |
import gradio as gr | |
import numpy as np | |
import nibabel as nib | |
import os | |
import tempfile | |
# 创建一个函数,接收3D数据并返回预测结果 | |
def predict_3d(data):
    """Placeholder predictor that reports the largest value in *data*.

    Stand-in for real 3D preprocessing and model inference.
    """
    peak_value = np.amax(data)
    return peak_value
# 创建一个用于读取和展示NIfTI数据的Gradio接口函数 | |
def interface():
    """Build a standalone Gradio Interface: upload a NIfTI file, show the
    ``predict_3d`` result as text.

    Not used by ``App``. Call ``.launch()`` on the returned object to serve it.

    :return: a ``gr.Interface`` wired to the file input and text output.
    """
    def predict(input_file):
        # Older Gradio passes a tempfile wrapper exposing ``.name``; fall back
        # to the value itself in case a plain path string is passed.
        path = getattr(input_file, 'name', input_file)
        nifti_data = nib.load(path)
        # Convert the NIfTI voxel data to a NumPy array.
        data = np.array(nifti_data.dataobj)
        # Preprocessing (resampling, normalization, ...) would go here.
        result = predict_3d(data)
        # One Textbox output -> return exactly one string.
        return str(result)

    return gr.Interface(
        fn=predict,
        inputs=gr.File(label="Upload NIfTI file"),
        outputs=gr.Textbox(),
    )
def get_Image_reslice(input_file):
    '''Load the uploaded volume and return one random slice along each axis.

    Side effects: stores the normalized volume in the global ``Image_3D`` and
    the case id (file stem up to its last underscore) in ``Current_name``.

    :param input_file: uploaded file object exposing the path as ``.name``.
    :return: ``(slice_z, slice_y, slice_x, z, y, x)``: three 2D arrays and the
        matching 1-based slider positions.
    '''
    global Image_3D
    global Current_name
    volume = sitk.GetArrayFromImage(sitk.ReadImage(input_file.name))
    # e.g. "<dir>/<case>_<suffix>.nii.gz" -> "<case>"
    Current_name = input_file.name.split(os.sep)[-1].split('.')[0].rsplit('_', 1)[0]
    # Promote integer volumes first so ``v_max - volume`` cannot overflow
    # (the result was float64 for integer input anyway, via true division).
    if np.issubdtype(volume.dtype, np.integer):
        volume = volume.astype(np.float64)
    v_max = np.max(volume)
    value_range = v_max - np.min(volume)
    if value_range == 0:
        # Constant volume: avoid 0/0 (NaN) and use an all-zero image.
        Image_3D = np.zeros_like(volume)
    else:
        # NOTE(review): this maps max -> 0 and min -> 1 (inverted intensities).
        # inference() feeds this array to the model, so keep it unless the
        # training preprocessing is confirmed to differ.
        Image_3D = (v_max - volume) / value_range
    # Random 0-based index per axis (drawn in z, y, x order).
    z = np.random.randint(0, Image_3D.shape[0])
    y = np.random.randint(0, Image_3D.shape[1])
    x = np.random.randint(0, Image_3D.shape[2])
    # The sliders are 1-based (the slice callbacks index with ``slice - 1``),
    # so report index + 1 to keep slider position and shown slice in sync.
    return (Image_3D[z, :, :], Image_3D[:, y, :], Image_3D[:, :, x],
            z + 1, y + 1, x + 1)
def _take_slice(axis, slice_number):
    """Return the 2D slice of ``Image_3D`` at 1-based *slice_number* along *axis*.

    Returns ``None`` (clears the image) when no volume is loaded yet. The
    slider value is cast to int and clamped to the axis length, because the
    sliders' fixed maxima can exceed the loaded volume's shape.
    """
    if Image_3D is None:
        return None
    index = int(slice_number) - 1
    index = min(max(index, 0), Image_3D.shape[axis] - 1)
    return np.take(Image_3D, index, axis=axis)


def change_image_slice_x(slice):
    """Slider callback: slice along the last axis (x)."""
    return _take_slice(2, slice)


def change_image_slice_y(slice):
    """Slider callback: slice along axis 1 (y)."""
    return _take_slice(1, slice)


def change_image_slice_z(slice):
    """Slider callback: slice along axis 0 (z)."""
    return _take_slice(0, slice)
def get_medical_message():
    """Return the ('past', 'now') clinical texts for the loaded case.

    :return: ``(past, now)`` from ``ALL_message[Current_name]``. Returns a
        "load data first" prompt when nothing is loaded, or a "not found"
        message when the case id has no entry in ``ALL_message``.
    """
    if Current_name is None:
        return '请先加载数据', ' '
    try:
        case_info = ALL_message[Current_name]
    except KeyError:
        # Uploaded file whose case id is not in the label pickle.
        return '未找到该病例的临床信息', ' '
    return case_info['past'], case_info['now']
class App:
    """Gradio front end: volume viewer, clinical-info lookup and stage result.

    Constructing an ``App`` builds the UI and launches it immediately.
    """

    def __init__(self):
        self.demo = None
        self.main()

    def main(self):
        """Build the Blocks UI, keep it on ``self.demo`` and launch it."""
        with gr.Blocks() as demo:
            inp = gr.File(label="Upload NIfTI file")
            btn1 = gr.Button("Upload Data")
            with gr.Tab("Image"):
                with gr.Row():
                    # NOTE(review): slider maxima are fixed (128/256/128) and do
                    # not follow the loaded volume's shape.
                    with gr.Column(scale=1):
                        out1 = gr.Image(shape=(10, 10))
                        slider1 = gr.Slider(1, 128, label='z轴层数', step=1, interactive=True)
                    with gr.Column(scale=1):
                        out2 = gr.Image(shape=(10, 10))
                        slider2 = gr.Slider(1, 256, label='y轴层数', step=1, interactive=True)
                    with gr.Column(scale=1):
                        out3 = gr.Image(shape=(10, 10))
                        slider3 = gr.Slider(1, 128, label='x轴层数', step=1, interactive=True)
                btn1.click(get_Image_reslice, inp, [out1, out2, out3, slider1, slider2, slider3])
                slider3.change(change_image_slice_x, inputs=slider3, outputs=out3)
                slider2.change(change_image_slice_y, inputs=slider2, outputs=out2)
                slider1.change(change_image_slice_z, inputs=slider1, outputs=out1)
            with gr.Tab("Medical Information"):
                with gr.Row():
                    with gr.Column(scale=1):
                        # Caption is the first (value) argument, as for the
                        # other buttons; label= would leave the default text.
                        btn2 = gr.Button("临床信息")
                        out4 = gr.Textbox(label="患病史")
                        out6 = gr.Textbox(label="现病史")
                    with gr.Column(scale=1):
                        btn3 = gr.Button("分期结果")
                        out5 = gr.Label(num_top_classes=2, label='分期结果')
                btn3.click(inference, inputs=None, outputs=out5)
                btn2.click(get_medical_message, inputs=None, outputs=[out4, out6])
        self.demo = demo
        demo.launch(share=True)


app = App()
# with gr.Blocks() as demo: | |
# with gr.Row(): | |
# with gr.Column(scale=1): | |
# text1 = gr.Textbox() | |
# text2 = gr.Textbox() | |
# with gr.Column(scale=4): | |
# btn1 = gr.Button("Button 1") | |
# btn2 = gr.Button("Button 2") | |
# demo.launch() | |