"""Gradio demo: browse a 3D MRI volume, show clinical notes and run staging.

The module keeps the currently loaded volume and case name in module-level
globals (``Image_3D`` / ``Current_name``) that the Gradio callbacks share.
"""
import os
import pickle
import tempfile

import gradio as gr
import nibabel as nib
import numpy as np
import SimpleITK as sitk
import torch
from scipy.ndimage import zoom

from resnet_gn import resnet50


def load_from_pkl(load_path):
    """Return the object stored in the pickle file at ``load_path``.

    NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    # The context manager closes the file even if unpickling raises.
    with open(load_path, 'rb') as data_input:
        return pickle.load(data_input)


# Volume currently shown in the UI (normalised array) and its case identifier.
Image_3D = None
Current_name = None

# Clinical records, indexed by case name, then by the keys 'past' and 'now'.
ALL_message = load_from_pkl(r'./label0601.pkl')

Model_Paht = r'./model_epoch62.pth.tar'
checkpoint = torch.load(Model_Paht, map_location='cpu')
classnet = resnet50(
    num_classes=1,
    sample_size=128,
    sample_duration=8)
classnet.load_state_dict(checkpoint['model_dict'])


def resize3D(img, aimsize, order=3):
    """Resample a 3D array to a target size with ``scipy.ndimage.zoom``.

    :param img: 3D array
    :param aimsize: list with one or three elements, like [256] or
        [256, 56, 56]. A ``None`` entry keeps that axis at its original
        size; any number of entries may be ``None``.
    :param order: spline interpolation order passed to ``zoom``
    :return: the resampled array
    """
    if len(aimsize) == 1:
        aimsize = [aimsize[0]] * 3
    # One zoom factor per axis; None means "leave this axis untouched".
    factors = tuple(
        1 if target is None else target / current
        for target, current in zip(aimsize, img.shape)
    )
    return zoom(img, factors, order=order)


def inference():
    """Run the classifier on the loaded volume and return class scores.

    :return: ``{'急性期': n, '亚急性期': p}`` where ``p`` is the first model
        output rounded to two decimals and ``n = 1 - p``; or the message
        '请先加载数据' when no volume has been loaded yet.
    """
    if Image_3D is None:
        # Same message get_medical_message() uses for the "nothing loaded" case.
        return '请先加载数据'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The model must live on the same device as its input.
    model = classnet.to(device)
    model.eval()

    with torch.no_grad():
        volume = torch.from_numpy(Image_3D)
        # Add batch and channel axes: (N, C_in, D_in, H_in, W_in).
        patch_data = volume.unsqueeze(0).unsqueeze(0).to(device).float()
        # NOTE(review): the original carried a TODO about applying a sigmoid
        # here; the raw output is used as the probability — confirm the
        # network already ends in a sigmoid.
        pre_probs = model(patch_data)
        # .cpu() is required before .numpy() when running on CUDA.
        pre_flat = pre_probs.view(-1).cpu().numpy()

    p = float(np.round(pre_flat[0], decimals=2))
    n = float(np.round(1 - p, decimals=2))
    return {'急性期': n, '亚急性期': p}


def predict_3d(data):
    """Placeholder prediction: return the maximum value of ``data``."""
    # Real 3D processing / prediction logic would go here.
    return np.max(data)


def interface():
    """Build a standalone Gradio Interface that reads a NIfTI file.

    :return: a ``gr.Interface`` whose function loads the uploaded file with
        nibabel and returns the ``predict_3d`` result twice, as strings.
    """
    input_component = gr.inputs.File(label="Upload NIfTI file")

    def predict(input_file):
        # Load the NIfTI file and convert it to a NumPy array.
        nifti_data = nib.load(input_file.name)
        data = np.array(nifti_data.dataobj)
        # Any preprocessing (scaling, normalisation, ...) would go here.
        result = predict_3d(data)
        return str(result), str(result)

    # predict() returns two strings, so two text outputs are declared.
    iface_1 = gr.Interface(
        fn=predict,
        inputs=input_component,
        outputs=[gr.Textbox(label="First"), gr.Textbox(label="Last")],
    )
    return iface_1


def get_Image_reslice(input_file):
    """Load a volume, normalise it and return one random slice per axis.

    Updates the module globals ``Image_3D`` and ``Current_name``.

    :param input_file: a path string, or an object whose ``.name`` is a path
    :return: ``(slice_z, slice_y, slice_x, z, y, x)`` where the last three
        values are the 1-based slice numbers used by the sliders.
    """
    global Image_3D
    global Current_name

    path = input_file if isinstance(input_file, str) else input_file.name
    volume = sitk.GetArrayFromImage(sitk.ReadImage(path))
    # Case name: file name up to the first '.', minus its last '_' suffix.
    Current_name = os.path.basename(path).split('.')[0].rsplit('_', 1)[0]

    v_max = np.max(volume)
    span = v_max - np.min(volume)
    if span == 0:
        # Constant volume: avoid 0/0 and show a uniform image instead.
        Image_3D = np.zeros(volume.shape, dtype=np.float64)
    else:
        # Inverted min-max normalisation (brightest voxel maps to 0).
        Image_3D = (v_max - volume) / span

    random_z = np.random.randint(0, Image_3D.shape[0])
    random_y = np.random.randint(0, Image_3D.shape[1])
    random_x = np.random.randint(0, Image_3D.shape[2])
    image_slice_z = Image_3D[random_z, :, :]
    image_slice_y = Image_3D[:, random_y, :]
    image_slice_x = Image_3D[:, :, random_x]

    # The slider callbacks treat their value as 1-based (they index with
    # value - 1), so report index + 1 to keep slider and image in sync.
    return (image_slice_z, image_slice_y, image_slice_x,
            random_z + 1, random_y + 1, random_x + 1)


def _slice_along(axis, slice_number):
    """Return the slice at 1-based ``slice_number`` along ``axis``.

    Returns ``None`` when no volume is loaded; out-of-range slice numbers
    are clamped to the first / last slice.
    """
    if Image_3D is None:
        return None
    index = int(slice_number) - 1
    index = min(max(index, 0), Image_3D.shape[axis] - 1)
    return np.take(Image_3D, index, axis=axis)


def change_image_slice_x(slice):
    """Return the slice for 1-based number ``slice`` along the last axis."""
    return _slice_along(2, slice)


def change_image_slice_y(slice):
    """Return the slice for 1-based number ``slice`` along the middle axis."""
    return _slice_along(1, slice)


def change_image_slice_z(slice):
    """Return the slice for 1-based number ``slice`` along the first axis."""
    return _slice_along(0, slice)


def get_medical_message():
    """Return the ``(past, now)`` clinical texts for the loaded case.

    Returns a placeholder message pair when no case is loaded or when the
    case has no entry in ``ALL_message``.
    """
    if Current_name is None:
        return '请先加载数据', ' '
    try:
        record = ALL_message[Current_name]
    except KeyError:
        return '未找到该病例的临床信息', ' '
    return record['past'], record['now']


class App:
    """Builds the Gradio Blocks UI and launches it on construction."""

    def __init__(self):
        self.demo = None
        self.main()

    def main(self):
        """Assemble the UI, wire the callbacks and launch the demo."""
        with gr.Blocks() as demo:
            inp = gr.inputs.File(label="Upload NIfTI file")
            btn1 = gr.Button("Upload Data")

            with gr.Tab("Image"):
                with gr.Row():
                    # Slider maxima are fixed; presumably volumes are
                    # 128 x 256 x 128 — TODO confirm. Out-of-range values
                    # are clamped by the slice callbacks.
                    with gr.Column(scale=1):
                        out1 = gr.Image(shape=(10, 10))
                        slider1 = gr.Slider(1, 128, label='z轴层数', step=1,
                                            interactive=True)
                    with gr.Column(scale=1):
                        out2 = gr.Image(shape=(10, 10))
                        slider2 = gr.Slider(1, 256, label='y轴层数', step=1,
                                            interactive=True)
                    with gr.Column(scale=1):
                        out3 = gr.Image(shape=(10, 10))
                        slider3 = gr.Slider(1, 128, label='x轴层数', step=1,
                                            interactive=True)
                btn1.click(get_Image_reslice, inp,
                           [out1, out2, out3, slider1, slider2, slider3])
                slider3.change(change_image_slice_x, inputs=slider3,
                               outputs=out3)
                slider2.change(change_image_slice_y, inputs=slider2,
                               outputs=out2)
                slider1.change(change_image_slice_z, inputs=slider1,
                               outputs=out1)

            with gr.Tab("Medical Information"):
                with gr.Row():
                    with gr.Column(scale=1):
                        # The caption is the first positional argument,
                        # matching the other buttons.
                        btn2 = gr.Button("临床信息")
                        out4 = gr.Textbox(label="患病史")
                        out6 = gr.Textbox(label="现病史")
                    with gr.Column(scale=1):
                        btn3 = gr.Button("分期结果")
                        out5 = gr.Label(num_top_classes=2, label='分期结果')
                btn3.click(inference, inputs=None, outputs=out5)
                btn2.click(get_medical_message, inputs=None,
                           outputs=[out4, out6])

            gr.Markdown("## Examples")
            # NOTE(review): inputs="text" relies on Gradio resolving the
            # string shortcut to a text component; get_Image_reslice accepts
            # a plain path string for this case — confirm against the
            # installed Gradio version.
            gr.Examples(
                examples=[os.path.join(os.path.dirname(__file__),
                                       "1093978_A_L_MRI.nii.gz")],
                inputs="text",
                outputs=[out1, out2, out3, slider1, slider2, slider3],
                fn=get_Image_reslice,
                cache_examples=True,
            )

        # Keep a handle on the Blocks object before the launch call.
        self.demo = demo
        demo.launch()


app = App()