import gradio as gr
import os
from omegaconf import OmegaConf
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer, ElucidatedImagenConfig, NullUnet, Imagen
import torch
import numpy as np
import cv2
from PIL import Image
import torchvision.transforms as T

device = "cuda" if torch.cuda.is_available() else "cpu"
exp_path = "model"
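
# Center crop that ignores the size passed to the constructor and always crops
# the largest centered square (side = min(H, W)) from the incoming tensor, so
# frames of any aspect ratio can be resized to 112x112 without distortion.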
class BetterCenterCrop(T.CenterCrop):
    def __call__(self, img):
        h = img.shape[-2]
        w = img.shape[-1]
        dim = min(h, w)
        return T.functional.center_crop(img, dim)
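
# Pool of conditioning frames: get_image() returns a random raw PIL image from
# the folder, while self.transform (ToTensor -> square center crop -> resize to
# 112x112) is applied later, when the image is prepared for the model.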
class ImageLoader:
    def __init__(self, path) -> None:
        self.path = path
        self.all_files = os.listdir(path)
        self.transform = T.Compose([
            T.ToTensor(),
            BetterCenterCrop((112, 112)),
            T.Resize((112, 112)),
        ])

    def get_image(self):
        idx = np.random.randint(0, len(self.all_files))
        img = Image.open(os.path.join(self.path, self.all_files[idx]))
        return img
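
# Holds the full generation pipeline: the cascaded Unet3D stack is rebuilt from
# the experiment config, wrapped in an (Elucidated)Imagen model and an
# ImagenTrainer, and the merged checkpoint is loaded from disk.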
class Context:
    def __init__(self, path, device):
        self.path = path
        self.config_path = os.path.join(path, "config.yaml")
        self.weight_path = os.path.join(path, "merged.pt")

        self.config = OmegaConf.load(self.config_path)
        self.config.dataset.num_frames = int(self.config.dataset.fps * self.config.dataset.duration)

        self.im_load = ImageLoader("echo_images")

        unets = []
        for i, (k, v) in enumerate(self.config.unets.items()):
            unets.append(Unet3D(**v, lowres_cond=(i > 0)))  # type: ignore

        imagen_klass = ElucidatedImagen if self.config.imagen.elucidated else Imagen
        del self.config.imagen.elucidated
        imagen = imagen_klass(
            unets=unets,
            **OmegaConf.to_container(self.config.imagen),  # type: ignore
        )

        self.trainer = ImagenTrainer(
            imagen=imagen,
            **self.config.trainer
        ).to(device)

        print("Loading weights from", self.weight_path)
        additional_data = self.trainer.load(self.weight_path)
        print("Loaded weights from", self.weight_path)
    def reshape_image(self, image):
        try:
            image = self.im_load.transform(image).multiply(255).byte().permute(1, 2, 0).numpy()
            return image
        except Exception:
            return None

    def load_random_image(self):
        print("Loading random image")
        image = self.im_load.get_image()
        return image
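
    # Samples one video conditioned on the selected frame and the target LVEF.
    # The LVEF is rescaled to [0, 1] and passed as a single one-dimensional
    # "text embedding"; cond_images carries the anatomy frame and cond_scale is
    # the classifier-free guidance weight.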
    def generate_video(self, image, lvef, cond_scale):
        print("Generating video")
        print(f"lvef: {lvef}, cond_scale: {cond_scale}")

        image = self.im_load.transform(image).unsqueeze(0)

        sample_kwargs = {
            "text_embeds": torch.tensor([[[lvef / 100.0]]]),
            "cond_scale": cond_scale,
            "cond_images": image,
        }

        self.trainer.eval()
        with torch.no_grad():
            video = self.trainer.sample(
                batch_size=1,
                video_frames=self.config.dataset.num_frames,
                **sample_kwargs,
                use_tqdm=True,
            ).detach().cpu()  # B x C x F x H x W

        if video.shape[-3:] != (64, 112, 112):
            video = torch.nn.functional.interpolate(video, size=(64, 112, 112), mode='trilinear', align_corners=False)

        video = video.repeat((1, 1, 5, 1, 1))  # loop the clip 5 times so it is easier to watch

        uid = np.random.randint(0, 10)  # avoid overwriting when several users run the app at once
        path = f"tmp/{uid}.mp4"

        video = video.multiply(255).byte().squeeze(0).permute(1, 2, 3, 0).numpy()  # F x H x W x C
        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), 32, (112, 112))
        for frame in video:
            out.write(frame)
        out.release()

        return path

context = Context(exp_path, device)
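
# Gradio UI: a description panel, the conditioning image, the generated video,
# two sliders (target LVEF and conditioning scale) and the two action buttons.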
with gr.Blocks(css="style.css") as demo:
    with gr.Row():
        gr.Label("Feature-Conditioned Cascaded Video Diffusion Models for Precise Echocardiogram Synthesis")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=3, variant="panel"):
                    text = gr.Markdown(value="This is a live demo of our work on cardiac ultrasound video generation. The model is trained on 4-chamber cardiac ultrasound videos and can generate realistic 4-chamber videos given a target Left Ventricle Ejection Fraction. Please start by sampling a random frame from the pool of 100 images taken from the EchoNet-Dynamic dataset; it acts as the conditional image and defines the anatomy of the video. Then set the target LVEF and click the button to generate a video. The process takes 30s to 60s. The model running here corresponds to the 1SCM from the paper. **Click on the video to play it.** [Code is available here](https://github.com/HReynaud/EchoDiffusion)")
                with gr.Column(scale=1, min_width="226"):
                    image = gr.Image(interactive=True)
                with gr.Column(scale=1, min_width="226"):
                    video = gr.Video(interactive=False)
            slider_ef = gr.Slider(minimum=10, maximum=90, step=1, label="Target LVEF", value=60, interactive=True)
            slider_cond = gr.Slider(minimum=0, maximum=20, step=1, label="Conditional scale (if set to more than 1, generation time is 60s)", value=1, interactive=True)
            with gr.Row():
                img_btn = gr.Button(value="❶ Get a random cardiac ultrasound image (4Ch)")
                run_btn = gr.Button(value="❷ Generate a video (~30s) 🚀")

    image.change(context.reshape_image, inputs=[image], outputs=[image])
    img_btn.click(context.load_random_image, inputs=[], outputs=[image])
    run_btn.click(context.generate_video, inputs=[image, slider_ef, slider_cond], outputs=[video])

if __name__ == "__main__":
    demo.queue()
    demo.launch()