"""Gradio demo: load an MRI volume and an ROI mask, crop/resample, and classify.

NOTE(review): all state is kept in module-level globals, so concurrent users of
one server process share (and overwrite) each other's data.
"""
import logging
import os
import pickle

import gradio as gr
import numpy as np
import SimpleITK as sitk
import torch
from scipy.ndimage import zoom

from resnet_gn import resnet50
#import tempfile

logger = logging.getLogger(__name__)


def load_from_pkl(load_path):
    """Return the object unpickled from ``load_path``.

    SECURITY: ``pickle.load`` can execute arbitrary code; only use this on
    trusted files that ship with the application.
    """
    with open(load_path, 'rb') as data_input:
        return pickle.load(data_input)


# Shared application state, written by the UI callbacks below.
Image_3D = None        # normalised MRI volume currently loaded
ROI_3D = None          # ROI mask currently loaded
Image_small_3D = None  # cropped + resampled volume produced by data_pretreatment
Input_File = None      # path of the MRI file currently loaded
Current_name = None    # case key derived from the MRI file name

ALL_message = load_from_pkl(r'./label0601.pkl')
Model_Paht = r'./model_epoch62.pth.tar'
checkpoint = torch.load(Model_Paht, map_location='cpu')
classnet = resnet50(
    num_classes=1,
    sample_size=128,
    sample_duration=8)
classnet.load_state_dict(checkpoint['model_dict'])


def _resolve_path(input_file):
    """Return a filesystem path for either a path string or an object with ``.name``."""
    if isinstance(input_file, str):
        return input_file
    return input_file.name


def _invert_normalize(arr):
    """Return ``(max - arr) / (max - min)``; all zeros when ``arr`` is constant.

    Integer input is promoted to float64 first so that ``max - min`` cannot
    overflow a narrow integer dtype.
    """
    arr = np.asarray(arr)
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)
    hi = np.max(arr)
    span = hi - np.min(arr)
    if span == 0:
        # Constant input: avoid 0/0 producing NaNs.
        return np.zeros_like(arr)
    return (hi - arr) / span


def resize3D(img, aimsize, order=3):
    """Resample a 3D array to a target shape with ``scipy.ndimage.zoom``.

    :param img: 3D array
    :param aimsize: list with one or three elements, like [256] or [256, 56, 56];
        a ``None`` entry leaves that axis unscaled
    :param order: spline interpolation order passed to ``zoom``
    :return: the resampled array
    """
    _shape = img.shape
    if len(aimsize) == 1:
        aimsize = [aimsize[0] for _ in range(3)]
    factors = tuple(
        1 if target is None else target / current
        for target, current in zip(aimsize[:3], _shape[:3])
    )
    return zoom(img, factors, order=order)


def inference():
    """Run the classifier on the pre-processed volume.

    :return: ``{'急性期': n, '亚急性期': p}`` where ``p`` is the rounded model
        output and ``n = 1 - p``; the string ``' '`` when no pre-processed
        volume exists or the forward pass fails.
    """
    if Image_small_3D is None:
        return ' '
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        # The model must live on the same device as its input.
        model = classnet.to(device)
        model.eval()
        with torch.no_grad():
            volume = torch.from_numpy(Image_small_3D)
            # Add batch and channel axes: (N, C_in, D_in, H_in, W_in)
            patch_data = volume.unsqueeze(0).unsqueeze(0).to(device).float()
            pre_probs = model(patch_data)
            # pre_probs = F.sigmoid(pre_probs)  # todo
            pre_flat = pre_probs.view(-1)
        p = np.round(float(pre_flat[0].item()), decimals=2)
        n = np.round(float(1 - p), decimals=2)
        return {'急性期': n, '亚急性期': p}
    except Exception:
        # Best-effort UI callback: log the cause instead of silently hiding it.
        logger.exception("inference failed")
        return ' '


def get_Image_reslice(input_file):
    """Load an MRI volume and return one random slice per axis.

    Updates the globals ``Image_3D``, ``Current_name`` and ``Input_File``.

    :param input_file: path string, or an object whose ``.name`` is the path
    :return: (z slice, y slice, x slice, z number, y number, x number, status
        message); slice numbers are 1-based to match the UI sliders
    """
    global Image_3D
    global Current_name
    global Input_File
    input_file = _resolve_path(input_file)
    Input_File = input_file
    print(input_file)
    raw = sitk.GetArrayFromImage(sitk.ReadImage(input_file))
    # Case key: file name without extensions and without its last "_suffix".
    Current_name = os.path.basename(input_file).split('.')[0].rsplit('_', 1)[0]
    Image_3D = _invert_normalize(raw)

    random_z = np.random.randint(0, Image_3D.shape[0])
    random_y = np.random.randint(0, Image_3D.shape[1])
    random_x = np.random.randint(0, Image_3D.shape[2])
    image_slice_z = Image_3D[random_z, :, :]
    image_slice_y = Image_3D[:, random_y, :]
    image_slice_x = Image_3D[:, :, random_x]
    # The sliders are 1-based (their handlers index with ``slice - 1``), so
    # report 1-based numbers to keep slider and displayed slice in agreement.
    return (image_slice_z, image_slice_y, image_slice_x,
            random_z + 1, random_y + 1, random_x + 1, '影像数据加载成功')


def get_ROI(input_file):
    """Load an ROI mask into the global ``ROI_3D`` and report whether it is binary.

    :param input_file: path string, or an object whose ``.name`` is the path
    :return: status message
    """
    global ROI_3D
    roi_array = sitk.GetArrayFromImage(sitk.ReadImage(_resolve_path(input_file)))
    ROI_3D = roi_array
    # Any value above 1 means the mask was not binarised.
    if np.any(roi_array > 1):
        return '这个数据没有经过二值化'
    return '感兴趣区域加载成功'


def _render_slice(axis, slice_number):
    """Return the 1-based ``slice_number`` along ``axis`` of ``Image_3D`` for display.

    The slice is copied (never a view of the global volume), clipped at its
    99.9th percentile, inverted/normalised and scaled to 0..255 as int16.
    Out-of-range slice numbers are clamped; a blank image is returned when no
    volume is loaded.
    """
    if Image_3D is None:
        return np.ones((10, 10))
    last_index = Image_3D.shape[axis] - 1
    index = int(np.clip(int(slice_number) - 1, 0, last_index))
    plane = np.take(Image_3D, index, axis=axis)  # np.take returns a copy
    cut_thre = np.percentile(plane, 99.9)  # discard the top 0.1% of the histogram
    clipped = np.minimum(plane, cut_thre)
    return (_invert_normalize(clipped) * 255).astype(np.int16)


def change_image_slice_x(slice):
    """Return the display image for 1-based slice ``slice`` along the last axis."""
    return _render_slice(2, slice)


def change_image_slice_y(slice):
    """Return the display image for 1-based slice ``slice`` along the middle axis."""
    return _render_slice(1, slice)


def change_image_slice_z(slice):
    """Return the display image for 1-based slice ``slice`` along the first axis."""
    return _render_slice(0, slice)


def get_medical_message():
    """Return the ``(past, now)`` record texts for the currently loaded case."""
    if Current_name is None:
        return '请先加载数据', ' '
    try:
        record = ALL_message[Current_name]
        return record['past'], record['now']
    except KeyError:
        # The loaded file has no (complete) entry in the label table.
        return '未找到该病例的临床信息', ' '


def clear_all():
    """Reset all loaded data and return blank values for the nine UI outputs."""
    global Image_3D
    global ROI_3D
    global Image_small_3D
    global Input_File
    global Current_name
    Current_name = None
    Image_3D = None
    # Also drop the ROI and the pre-processed volume so that a later
    # "inference" cannot run on stale data.
    ROI_3D = None
    Image_small_3D = None
    Input_File = None
    return np.ones((10, 10)), np.ones((10, 10)), np.ones((10, 10)), '', '', ' ',"尚未进行预处理 请先预处理再按“分期结果”按钮","尚未加载影像数据","尚未加载感兴趣区域"


def get_box(mask):
    """Return the bounding box of the positive voxels of a 3D mask.

    :param mask: 3D array (ground-truth / ROI mask)
    :return: [dim0min, dim0max, dim1min, dim1max, dim2min, dim2max]
    :raises ValueError: if the mask has no positive voxel
    """
    # np.where returns one index array per dimension.
    indexx = np.where(mask > 0.)
    if indexx[0].size == 0:
        raise ValueError('mask contains no positive voxels; cannot compute a bounding box')
    bbox = [np.min(indexx[0]), np.max(indexx[0]),
            np.min(indexx[1]), np.max(indexx[1]),
            np.min(indexx[2]), np.max(indexx[2])]
    return bbox


def arry_crop_3D(img, mask, ex_pix):
    """Crop ``img`` and ``mask`` to the mask's bounding box, expanded by a margin.

    :param img: 3D array
    :param mask: 3D array
    :param ex_pix: list [a, b, c] (or [a] for all axes): how far to expand on
        each side, in the same axis order as the input; a falsy first entry
        disables expansion along the first axis
    :return: (cropped image, cropped mask); both are views of the inputs
    :raises ValueError: if ``ex_pix`` has exactly two entries
    """
    if len(ex_pix) == 1:
        ex_pix = [ex_pix[0] for _ in range(3)]
    elif len(ex_pix) == 2:
        raise ValueError('如果z轴不外扩,第一维请输入0')

    dim0min, dim0max, dim1min, dim1max, dim2min, dim2max = get_box(mask)
    dim0, dim1, dim2 = img.shape
    # Expand as far as requested, stopping at the volume border.
    dim1_l_index = np.clip(dim1min - ex_pix[1], 0, dim1)
    dim1_r_index = np.clip(dim1max + ex_pix[1], 0, dim1)
    dim2_l_index = np.clip(dim2min - ex_pix[2], 0, dim2)
    dim2_r_index = np.clip(dim2max + ex_pix[2], 0, dim2)
    if ex_pix[0]:
        dim0_l_index = np.clip(dim0min - ex_pix[0], 0, dim0)
        dim0_r_index = np.clip(dim0max + ex_pix[0], 0, dim0)
    else:
        print('dim0 不外扩')
        dim0_l_index = dim0min
        dim0_r_index = dim0max

    fina_img = img[dim0_l_index:dim0_r_index + 1,
                   dim1_l_index:dim1_r_index + 1,
                   dim2_l_index:dim2_r_index + 1]
    fina_mask = mask[dim0_l_index:dim0_r_index + 1,
                     dim1_l_index:dim1_r_index + 1,
                     dim2_l_index:dim2_r_index + 1]
    return fina_img, fina_mask


def data_pretreatment():
    """Crop the loaded volume around the ROI, clip, resample and normalise it.

    Stores the result in the global ``Image_small_3D``.

    :return: status message
    """
    global Image_small_3D
    if Image_3D is None or ROI_3D is None:
        return '没有数据'
    waikuo = [4, 4, 4]
    fina_img, _fina_mask = arry_crop_3D(Image_3D, ROI_3D, waikuo)
    cut_thre = np.percentile(fina_img, 99.9)  # discard the top 0.1% of the histogram
    # np.minimum builds a new array; the crop is a view of Image_3D and must
    # not be modified in place.
    fina_img = np.minimum(fina_img, cut_thre)
    fina_img = resize3D(fina_img, [128, 256, 128], order=3)
    Image_small_3D = _invert_normalize(fina_img)
    return '预处理结束'


class App:
    """Builds and launches the Gradio user interface."""

    def __init__(self):
        self.demo = None
        self.main()

    def main(self):
        """Create the Blocks layout, wire the callbacks, and launch the server."""
        # NOTE(review): the nesting of the layout below was reconstructed from
        # a source whose indentation was lost - confirm against the original UI.
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column(scale=1):
                    inp = gr.inputs.File(label="Upload MRI file")
                    inp2 = gr.inputs.File(label="Upload ROI file")
                with gr.Column(scale=1):
                    out8 = gr.Textbox(placeholder="尚未加载影像数据")
                    out9 = gr.Textbox(placeholder="尚未加载感兴趣区域")
            with gr.Row():
                btn1 = gr.Button("Upload MRI")
                btn5 = gr.Button("Upload ROI")
                clear = gr.Button(" Clear All")
            with gr.Tab("Image"):
                with gr.Row():
                    with gr.Column(scale=1):
                        out1 = gr.Image(shape=(10, 10))
                        slider1 = gr.Slider(1, 128, label='z轴层数', step=1, interactive=True)
                    with gr.Column(scale=1):
                        out2 = gr.Image(shape=(10, 10))
                        slider2 = gr.Slider(1, 256, label='y轴层数', step=1, interactive=True)
                    with gr.Column(scale=1):
                        out3 = gr.Image(shape=(10, 10))
                        slider3 = gr.Slider(1, 128, label='x轴层数', step=1, interactive=True)
            with gr.Tab("Medical Information"):
                with gr.Row():
                    with gr.Column(scale=1):
                        btn2 = gr.Button(value="临床信息")
                        out4 = gr.Textbox(label="患病史")
                        out6 = gr.Textbox(label="现病史")
                    with gr.Column(scale=1):
                        btn4 = gr.Button("预处理")
                        out7 = gr.Textbox(placeholder="尚未进行预处理 请先预处理再按“分期结果”按钮", )
                        btn3 = gr.Button("分期结果")
                        out5 = gr.Label(num_top_classes=2, label='分期结果')

            btn3.click(inference, inputs=None, outputs=out5)
            btn4.click(data_pretreatment, inputs=None, outputs=out7)
            btn2.click(get_medical_message, inputs=None, outputs=[out4, out6])
            btn1.click(get_Image_reslice, inp, [out1, out2, out3, slider1, slider2, slider3,out8])
            btn5.click(get_ROI, inputs=inp2, outputs=out9)
            slider3.change(change_image_slice_x, inputs=slider3, outputs=out3)
            slider2.change(change_image_slice_y, inputs=slider2, outputs=out2)
            slider1.change(change_image_slice_z, inputs=slider1, outputs=out1)
            clear.click(clear_all, None, [out1, out2, out3, out4, out6, out5, out7,out8,out9], queue=True)

            gr.Markdown('''# Examples''')
            # NOTE(review): with cache_examples=True the example outputs are
            # presumably served from a cache, in which case selecting an example
            # would not re-run the loader and the globals it sets - confirm.
            gr.Examples(
                examples=[["./155086_A_R_MRI.nii.gz"], ["./4077798_A_L_MRI.nii.gz"]],
                inputs=inp,
                outputs=[out1, out2, out3, slider1, slider2, slider3,out8],
                fn=get_Image_reslice,
                cache_examples=True,
            )
            gr.Examples(
                examples=[["./155086_A_R_ROI.nii.gz"], ["./4077798_A_L_ROI.nii.gz"]],
                inputs=inp2,
                outputs=out9,
                fn=get_ROI,
                cache_examples=True,
            )

        # Keep a handle on the built interface before the blocking launch call.
        self.demo = demo
        demo.queue(concurrency_count=6)
        demo.launch(share=False)


app = App()