import contextlib
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import osail_utils
import pandas as pd
import skimage
from mediffusion import DiffusionModule
import monai as mn
import torch

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Length of the zero-filled tail appended to every classifier-condition vector.
# NOTE(review): presumably a feature slot the trained model expects to be
# zero at inference time — confirm against the training config.
COND_PAD_DIM = 768
# First entry of the condition vector: 1 when explicit angles are requested,
# 2 (together with the sentinel angles below) for "rotate to standard view".
ROTATE_FLAG = 1
STANDARD_FLAG = 2
STANDARD_SENTINEL_ANGLES = [1000, 1000, 1000]
# Angles (degrees) precomputed by "Make Live!"; must cover the live slider range.
LIVE_ANGLES = [float(a) for a in range(-20, 21)]
EXAMPLES_DIR = "./data/examples"

# ---------------------------------------------------------------------------
# Loading the model for inference
# ---------------------------------------------------------------------------
model = DiffusionModule("./diffusion_configs.yaml")
model.load_ckpt("./data/model.ckpt")
model.cuda().half()
model.eval()

# Loading a baseline noise for making predictions: a fixed seed makes every
# request start from the same noise tensor.
seed = 3407
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
BASELINE_NOISE = torch.randn(1, 1, 256, 256).cuda().half()


# ---------------------------------------------------------------------------
# Model helper functions
# ---------------------------------------------------------------------------
def create_ds(img_paths):
    """Build a MONAI dataset that loads and preprocesses one or more images.

    Args:
        img_paths: A single image path (``str``) or a list of paths.

    Returns:
        A ``monai.data.Dataset`` yielding ``{"img": tensor}`` items. Each image
        is resized to 256x256 (bicubic) and intensity-scaled to [0, 1].
    """
    if isinstance(img_paths, str):
        img_paths = [img_paths]
    data_list = [{"img": img_path} for img_path in img_paths]

    transforms = mn.transforms.Compose([
        osail_utils.io.LoadImageD(keys=["img"], transpose=True, normalize=True),
        mn.transforms.EnsureChannelFirstD(keys=["img"], channel_dim="no_channel"),
        mn.transforms.ResizeD(keys=["img"], spatial_size=(256, 256), mode=["bicubic"]),
        mn.transforms.ScaleIntensityD(keys=["img"], minv=0, maxv=1),
        mn.transforms.ToTensorD(keys=["img"], track_meta=None),
        mn.transforms.SelectItemsD(keys=["img"]),
    ])
    return mn.data.Dataset(data_list, transform=transforms)


def _make_cls_vector(flag, angles):
    """Return the 1-D float32 condition vector ``[flag, *angles, 0, ..., 0]``.

    The zero tail has ``COND_PAD_DIM`` entries, so the result has
    ``1 + len(angles) + COND_PAD_DIM`` elements.
    """
    head = torch.tensor([flag, *angles], dtype=torch.float32)
    return torch.cat([head, torch.zeros(COND_PAD_DIM)])


def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=False, sampler="DDIM100"):
    """Run the diffusion model on an image and return histogram-matched outputs.

    Args:
        img_path: Path of the input image.
        angles: ``[x, y, z]`` rotation angles. Ignored when ``cls_batch`` is
            given. When None (or ``rotate_to_standard`` is True), the
            standard-view condition is used instead.
        cls_batch: Optional precomputed condition batch, one row per
            prediction. The same image is then repeated ``len(cls_batch)``
            times.
        rotate_to_standard: Request the standard-view condition.
        sampler: Passed to ``model.predict`` as ``inference_protocol``.

    Returns:
        A list of 2-D numpy arrays, one per batch item, each histogram-matched
        to its (preprocessed) input image.
    """
    # Create the image dataset; one copy of the image per condition row.
    if cls_batch is not None:
        ds = create_ds([img_path] * len(cls_batch))
    else:
        ds = create_ds(img_path)
    dl = mn.data.DataLoader(ds, batch_size=len(ds), num_workers=0 if len(ds) == 1 else 4, shuffle=False)
    imgs = next(iter(dl))["img"]
    batch_size = imgs.shape[0]
    original_imgs = imgs.detach().cpu().numpy()

    # Create the classifier condition if not provided.
    if cls_batch is None:
        if rotate_to_standard or angles is None:
            cls_value = _make_cls_vector(STANDARD_FLAG, STANDARD_SENTINEL_ANGLES)
        else:
            cls_value = _make_cls_vector(ROTATE_FLAG, angles)
        cls_batch = cls_value.unsqueeze(0).repeat(batch_size, 1)
    # Move every condition (built here or caller-supplied) to the model's
    # device/dtype. Previously a caller-supplied batch stayed on the CPU in
    # float32 while the noise and concat inputs were CUDA half.
    cls_batch = torch.as_tensor(cls_batch).cuda().half()

    noise = BASELINE_NOISE.repeat(batch_size, 1, 1, 1)
    model_kwargs = {
        "cls": cls_batch,
        "concat": imgs.cuda().half(),
    }

    # Make predictions.
    preds = model.predict(
        noise, model_kwargs=model_kwargs, classifier_cond_scale=4, inference_protocol=sampler
    )

    # Restore the input's intensity distribution on each output.
    adjusted_preds = []
    for pred, original_img in zip(preds, original_imgs):
        pred_np = pred.detach().cpu().numpy().squeeze()
        adjusted_preds.append(skimage.exposure.match_histograms(pred_np, original_img.squeeze()))
    return adjusted_preds


# ---------------------------------------------------------------------------
# Gradio helper functions
# ---------------------------------------------------------------------------
# NOTE(review): this module-level state is shared by every session served by
# this process; consider gr.State if concurrent multi-user use matters.
current_img = None       # last colormapped output (only set when add_bone_cmap=True)
live_preds = None        # {angle (float): prediction} from the last "Make Live!"
live_input_path = None   # input image path used for the last "Make Live!"

_LIVE_AXIS_INDEX = {"axis x": 0, "axis y": 1, "axis z": 2}


def _require_image(img_path):
    """Raise a user-visible Gradio error when no input image has been selected."""
    if not img_path:
        raise gr.Error("Please select an input image first.")


def _apply_bone_cmap(img):
    """Colorize a single-channel image with matplotlib's 'bone' colormap (RGB uint8)."""
    rgba = plt.get_cmap('bone')(img)
    return (rgba[..., :3] * 255).astype(np.uint8)


def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
    """Rotate the input image by the slider angles and return the result."""
    global current_img
    _require_image(img_path)
    angles = [float(xt), float(yt), float(zt)]
    out_img = make_predictions(img_path, angles)[0]
    if not add_bone_cmap:
        return out_img
    current_img = _apply_bone_cmap(out_img)
    return current_img


def rotate_to_standard_btn_fn(img_path, add_bone_cmap=False):
    """Rotate the input image to the standard view and return the result."""
    global current_img
    _require_image(img_path)
    out_img = make_predictions(img_path, rotate_to_standard=True)[0]
    if not add_bone_cmap:
        return out_img
    current_img = _apply_bone_cmap(out_img)
    return current_img


def use_current_btn_fn(input_img):
    """Pass the current output image through unchanged so it becomes the new input."""
    return input_img


def make_live_btn_fn(img_path, axis, add_bone_cmap=False):
    """Precompute predictions for every angle in ``LIVE_ANGLES`` about one axis.

    Results are stored in ``live_preds`` for the live slider, and the input path
    is remembered so the slider can show the original image at angle 0.
    Returns ``img_path`` unchanged so the live image keeps showing the input.
    """
    global live_preds, live_input_path
    _require_image(img_path)
    try:
        axis_idx = _LIVE_AXIS_INDEX[axis.lower()]
    except (KeyError, AttributeError):
        raise gr.Error(f"Unknown axis: {axis!r}") from None

    all_angles = []
    for angle in LIVE_ANGLES:
        triple = [0.0, 0.0, 0.0]
        triple[axis_idx] = angle
        all_angles.append(triple)

    cls_batch = torch.stack([_make_cls_vector(ROTATE_FLAG, angles) for angles in all_angles])
    preds = make_predictions(img_path, cls_batch=cls_batch)
    live_preds = dict(zip(LIVE_ANGLES, preds))
    live_input_path = img_path
    return img_path


def rotate_live_img_fn(angle, add_bone_cmap=False):
    """Return the precomputed prediction for the slider ``angle``.

    Angle 0 shows the original input image. Before "Make Live!" has run (or
    for an angle that was not precomputed), the displayed image is left
    unchanged.
    """
    if live_preds is None:
        return gr.update()
    angle = float(angle)
    if angle == 0 and live_input_path is not None:
        return live_input_path
    pred = live_preds.get(angle)
    return gr.update() if pred is None else pred


def _example_paths(tag):
    """Return example image paths in EXAMPLES_DIR whose file name contains ``tag`` (sorted)."""
    return [os.path.join(EXAMPLES_DIR, f) for f in sorted(os.listdir(EXAMPLES_DIR)) if tag in f]


XR_EXAMPLES = _example_paths("xr")
DRR_EXAMPLES = _example_paths("drr")

# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
css_style = "./style.css"
callback = gr.CSVLogger()

with gr.Blocks(css=css_style) as app:
    gr.HTML("VCNet: A Deep Learning Solution for Rotating Radiographs in 3D Space", elem_classes="title")
    gr.HTML("Developed by the Orthopedics Surgery Artificial Intelligence Lab (OSAIL)", elem_classes="note")
    gr.HTML("Note: This is a proof-of-concept demo of an AI tool that is not yet finalized. Please interpret with care!", elem_classes="note")

    with gr.TabItem("Single Rotation"):
        with gr.Row():
            input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
            output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
        with gr.Row():
            gr.Examples(examples=XR_EXAMPLES, inputs=[input_img], label="Xray Examples", elem_id='examples')
            gr.Examples(examples=DRR_EXAMPLES, inputs=[input_img], label="DRR Examples", elem_id='examples')
        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
        with gr.Row():
            with gr.Column(scale=1):
                xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
        with gr.Row():
            rotate_to_standard_btn = gr.Button("Rotate to standard view!", elem_classes='rotate_to_standard_button')
        with gr.Row():
            use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')
        rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
        rotate_to_standard_btn.click(fn=rotate_to_standard_btn_fn, inputs=[input_img], outputs=output_img)
        use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)

    with gr.TabItem("Live Rotation"):
        with gr.Row():
            live_img = gr.Image(type='filepath', label='Live Image', sources='upload', interactive=False, elem_classes='imgs')
        with gr.Row():
            gr.Examples(examples=XR_EXAMPLES, inputs=[live_img], label="Xray Examples", elem_id='examples')
            gr.Examples(examples=DRR_EXAMPLES, inputs=[live_img], label="DRR Examples", elem_id='examples')
        with gr.Row():
            gr.Markdown('Please select an example image, an axis, and then press Make Live!', elem_classes='text')
        with gr.Row():
            axis = gr.Dropdown(choices=['Axis X', 'Axis Y', 'Axis Z'], show_label=False, elem_classes='angle', value='Axis X')
            live_btn = gr.Button("Make Live!", elem_classes='make_live_button')
        with gr.Row():
            gr.Markdown('You can now rotate the radiograph in your selected axis using the slider.', elem_classes='text')
        with gr.Row():
            slider = gr.Slider(show_label=False, minimum=-20, maximum=20, step=1, value=0, elem_classes='slider', interactive=True)
        live_btn.click(fn=make_live_btn_fn, inputs=[live_img, axis], outputs=live_img)
        slider.change(fn=rotate_live_img_fn, inputs=[slider], outputs=live_img)

# Best-effort cleanup of any previously running Gradio servers. Each call is
# guarded on its own so that a failure in one does not skip the other.
with contextlib.suppress(Exception):
    app.close()
with contextlib.suppress(Exception):
    gr.close_all()

demo = app.launch(
    max_threads=4,
    share=True,
    inline=False,
    show_api=False,
    show_error=True,
    server_port=1902,
    server_name="0.0.0.0",
)