import os

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from omegaconf import OmegaConf
from PIL import Image

from imagen_pytorch import ElucidatedImagen, Imagen, ImagenTrainer, Unet3D

device = "cuda" if torch.cuda.is_available() else "cpu"
exp_path = "model"


class BetterCenterCrop(T.CenterCrop):
    """Center-crop to the largest square that fits the input, ignoring the configured size."""

    def __call__(self, img):
        h = img.shape[-2]
        w = img.shape[-1]
        dim = min(h, w)
        return T.functional.center_crop(img, dim)


class ImageLoader:
    """Serves random conditioning frames from a folder of echo images."""

    def __init__(self, path) -> None:
        self.path = path
        self.all_files = os.listdir(path)
        self.transform = T.Compose([
            T.ToTensor(),
            BetterCenterCrop((112, 112)),
            T.Resize((112, 112)),
        ])

    def get_image(self):
        idx = np.random.randint(0, len(self.all_files))
        img = Image.open(os.path.join(self.path, self.all_files[idx]))
        return img


class Context:
    """Builds the cascaded video diffusion model and runs sampling."""

    def __init__(self, path, device):
        self.path = path
        self.config_path = os.path.join(path, "config.yaml")
        self.weight_path = os.path.join(path, "merged.pt")

        self.config = OmegaConf.load(self.config_path)
        self.config.dataset.num_frames = int(self.config.dataset.fps * self.config.dataset.duration)

        self.im_load = ImageLoader("echo_images")

        # Build the U-Net cascade; every stage after the first is conditioned
        # on the low-resolution output of the previous stage.
        unets = []
        for i, (k, v) in enumerate(self.config.unets.items()):
            unets.append(Unet3D(**v, lowres_cond=(i > 0)))  # type: ignore

        imagen_klass = ElucidatedImagen if self.config.imagen.elucidated else Imagen
        del self.config.imagen.elucidated
        imagen = imagen_klass(
            unets=unets,
            **OmegaConf.to_container(self.config.imagen),  # type: ignore
        )

        self.trainer = ImagenTrainer(
            imagen=imagen,
            **self.config.trainer,
        ).to(device)

        print("Loading weights from", self.weight_path)
        self.trainer.load(self.weight_path)
        print("Loaded weights from", self.weight_path)

    def reshape_image(self, image):
        try:
            return self.im_load.transform(image).multiply(255).byte().permute(1, 2, 0).numpy()
        except Exception:
            return None

    def load_random_image(self):
        print("Loading random image")
        return self.im_load.get_image()

    def generate_video(self, image, lvef, cond_scale):
        print("Generating video")
        print(f"lvef: {lvef}, cond_scale: {cond_scale}")

        image = self.im_load.transform(image).unsqueeze(0)
        sample_kwargs = {
            # The scalar LVEF is passed as a (1, 1, 1) embedding, normalised to [0, 1].
            "text_embeds": torch.tensor([[[lvef / 100.0]]]),
            "cond_scale": cond_scale,
            "cond_images": image,
        }

        self.trainer.eval()
        with torch.no_grad():
            video = self.trainer.sample(
                batch_size=1,
                video_frames=self.config.dataset.num_frames,
                use_tqdm=True,
                **sample_kwargs,
            ).detach().cpu()  # B x C x F x H x W

        # Resample to the fixed 64 x 112 x 112 shape expected by the video writer below.
        if video.shape[-3:] != (64, 112, 112):
            video = torch.nn.functional.interpolate(video, size=(64, 112, 112), mode="trilinear", align_corners=False)

        video = video.repeat((1, 1, 5, 1, 1))  # loop the clip 5 times - easier to see

        os.makedirs("tmp", exist_ok=True)  # cv2.VideoWriter fails silently if the directory is missing
        uid = np.random.randint(0, 10)  # random id to limit overwriting between concurrent users; 10 ids also keep tmp/ small
        path = f"tmp/{uid}.mp4"

        video = video.multiply(255).byte().squeeze(0).permute(1, 2, 3, 0).numpy()  # F x H x W x C
        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), 32, (112, 112))
        for frame in video:
            out.write(frame)  # cv2 expects BGR; echo frames are grayscale-like, so the channel order is left as-is
        out.release()

        return path
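
# Optional headless helper (an added sketch, not part of the original demo):
# it exercises the same sampling path the UI uses, which is handy for
# smoke-testing a checkpoint without launching Gradio. The default lvef and
# cond_scale values are illustrative assumptions, not values from the paper.
def smoke_test(ctx, lvef=60, cond_scale=1):
    frame = ctx.load_random_image()
    return ctx.generate_video(frame, lvef, cond_scale)
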
context = Context(exp_path, device)

with gr.Blocks(css="style.css") as demo:
    with gr.Row():
        gr.Label("Feature-Conditioned Cascaded Video Diffusion Models for Precise Echocardiogram Synthesis")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=3, variant="panel"):
                    text = gr.Markdown(
                        value=(
                            "This is a live demo of our work on cardiac ultrasound video generation. "
                            "The model is trained on 4-chamber cardiac ultrasound videos and can generate "
                            "realistic 4-chamber videos given a target Left Ventricle Ejection Fraction. "
                            "Please start by sampling a random frame from the pool of 100 images taken "
                            "from the EchoNet-Dynamic dataset; it acts as the conditioning image and "
                            "defines the anatomy of the video. Then set the target LVEF and click the "
                            "button to generate a video. The process takes 30 to 60 seconds. The model "
                            "running here corresponds to the 1SCM from the paper. "
                            "**Click on the video to play it.** "
                            "[Code is available here](https://github.com/HReynaud/EchoDiffusion)"
                        )
                    )
                with gr.Column(scale=1, min_width=226):
                    image = gr.Image(interactive=True)
                with gr.Column(scale=1, min_width=226):
                    video = gr.Video(interactive=False)
            slider_ef = gr.Slider(minimum=10, maximum=90, step=1, label="Target LVEF", value=60, interactive=True)
            slider_cond = gr.Slider(minimum=0, maximum=20, step=1, label="Conditioning scale (values above 1 raise generation time to ~60s)", value=1, interactive=True)
            with gr.Row():
                img_btn = gr.Button(value="❶ Get a random cardiac ultrasound image (4Ch)")
                run_btn = gr.Button(value="❷ Generate a video (~30s) 🚀")

    image.change(context.reshape_image, inputs=[image], outputs=[image])
    img_btn.click(context.load_random_image, inputs=[], outputs=[image])
    run_btn.click(context.generate_video, inputs=[image, slider_ef, slider_cond], outputs=[video])

if __name__ == "__main__":
    demo.queue()
    demo.launch()
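
# Example use of the optional smoke_test helper defined above, assuming this
# file is saved as app.py (an assumption about the repository layout):
#   python -c "import app; print(app.smoke_test(app.context))"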