Commit 316f1d5: Duplicate from anon-SGXT/echocardiogram-video-diffusion
Co-authored-by: Anonymous <[email protected]>
(Diff view truncated to the first 50 files.)
- .gitattributes +34 -0
- .gitignore +3 -0
- README.md +14 -0
- app.py +145 -0
- echo_images/0X10094BA0A028EAC3.png +0 -0
- echo_images/0X1013E8A4864781B.png +0 -0
- echo_images/0X12B890B1E2E14CC4.png +0 -0
- echo_images/0X13E043A35E3EB490.png +0 -0
- echo_images/0X159BDA520C61736A.png +0 -0
- echo_images/0X15DA8D60960ABB2B.png +0 -0
- echo_images/0X16AF26F9A372EEDE.png +0 -0
- echo_images/0X17BC4EF4BF83368B.png +0 -0
- echo_images/0X1B379931357428C0.png +0 -0
- echo_images/0X1CDD9C054D8FB60D.png +0 -0
- echo_images/0X1DF7163A74801695.png +0 -0
- echo_images/0X20C397F012441121.png +0 -0
- echo_images/0X22A1A8A656653343.png +0 -0
- echo_images/0X22D7FDCF2827269E.png +0 -0
- echo_images/0X230F00FD0DF5D71C.png +0 -0
- echo_images/0X244CAB3550320216.png +0 -0
- echo_images/0X24FEF7D294B35A5B.png +0 -0
- echo_images/0X25D970C75A57B3F2.png +0 -0
- echo_images/0X277FC348812C0E79.png +0 -0
- echo_images/0X27836E538BD008A.png +0 -0
- echo_images/0X2840438B29E95F1F.png +0 -0
- echo_images/0X29A336DCE20541A0.png +0 -0
- echo_images/0X29C81728B50A2E6C.png +0 -0
- echo_images/0X2A830BC4A3A36A93.png +0 -0
- echo_images/0X2AD994F98C491FA6.png +0 -0
- echo_images/0X2BB766EF1A13DECC.png +0 -0
- echo_images/0X2DA99F9FC1DAD8A9.png +0 -0
- echo_images/0X3545F8A008B34ED0.png +0 -0
- echo_images/0X36E4468C9E659B89.png +0 -0
- echo_images/0X39CA8CC96A5D5E8B.png +0 -0
- echo_images/0X3B01B7487E3D81EA.png +0 -0
- echo_images/0X3B0D2D527C387A0E.png +0 -0
- echo_images/0X3B54A5459841DCE8.png +0 -0
- echo_images/0X3B9FBD87EE113D62.png +0 -0
- echo_images/0X3BA9F7C9DB0CF55B.png +0 -0
- echo_images/0X3DA2B290B58A6540.png +0 -0
- echo_images/0X3E2F182038897EA5.png +0 -0
- echo_images/0X3F076329C702F768.png +0 -0
- echo_images/0X4130EB4CD7ED958B.png +0 -0
- echo_images/0X42E8226CA93B7BAC.png +0 -0
- echo_images/0X45418C574D97027A.png +0 -0
- echo_images/0X45CE057EC2EB577F.png +0 -0
- echo_images/0X463A7B7D46C6CA4.png +0 -0
- echo_images/0X463C296E8E65DA97.png +0 -0
- echo_images/0X46682D67FA3FE237.png +0 -0
- echo_images/0X487B52623BC14C25.png +0 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
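This is the standard attributes file generated for new Hugging Face repos: any file matching one of these patterns is stored via Git LFS instead of regular Git history. As a rough illustration (an assumption: `fnmatch` glob semantics only approximate gitattributes matching, but they agree for simple `*.ext` patterns like these), the sketch below checks which repo files the rules would route to LFS:

```python
# Illustrative sketch only, not part of the commit. Checks a few repo paths
# against a subset of the LFS patterns above; fnmatch approximates
# gitattributes glob matching for these simple "*.ext" patterns.
from fnmatch import fnmatch
import os

lfs_patterns = ["*.pt", "*.pth", "*.ckpt", "*.safetensors", "*.zip", "*.h5"]
repo_files = ["model/merged.pt", "app.py", "echo_images/0X10094BA0A028EAC3.png"]

for path in repo_files:
    in_lfs = any(fnmatch(os.path.basename(path), pat) for pat in lfs_patterns)
    print(f"{path}: {'LFS' if in_lfs else 'regular git'}")
```

So the model weights (`model/merged.pt`, matched by `*.pt`) would live in LFS, while the app code and the PNG frames stay in regular Git.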
.gitignore
ADDED
@@ -0,0 +1,3 @@
+*.mp4
+*.ipynb
+*__pycache__*
README.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: EchoNet Video Diffusion
+emoji: 🖤
+colorFrom: gray
+colorTo: purple
+sdk: gradio
+sdk_version: 3.17.0
+app_file: app.py
+pinned: false
+license: mit
+duplicated_from: anon-SGXT/echocardiogram-video-diffusion
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,145 @@
+import gradio as gr
+import os
+from omegaconf import OmegaConf
+from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer, ElucidatedImagenConfig, NullUnet, Imagen
+import torch
+import numpy as np
+import cv2
+from PIL import Image
+import torchvision.transforms as T
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+exp_path = "model"
+
+class BetterCenterCrop(T.CenterCrop):
+    def __call__(self, img):
+        h = img.shape[-2]
+        w = img.shape[-1]
+        dim = min(h, w)
+
+        return T.functional.center_crop(img, dim)
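Despite subclassing `T.CenterCrop`, this class ignores the size passed to its constructor: `__call__` always crops to the largest centered square (side `min(H, W)`), which the `T.Resize` used later in the transform pipeline then scales to 112×112. A minimal standalone sketch of that behavior (not part of the commit):

```python
# Crop a non-square C x H x W tensor to its largest centered square,
# mirroring what BetterCenterCrop does regardless of its size argument.
import torch
import torchvision.transforms.functional as TF

frame = torch.rand(3, 120, 160)                        # C x H x W, wider than tall
square = TF.center_crop(frame, min(frame.shape[-2:]))  # crop to min(H, W)
print(square.shape)                                    # torch.Size([3, 120, 120])
```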
+
+class ImageLoader:
+    def __init__(self, path) -> None:
+        self.path = path
+        self.all_files = os.listdir(path)
+        self.transform = T.Compose([
+            T.ToTensor(),
+            BetterCenterCrop((112, 112)),
+            T.Resize((112, 112)),
+        ])
+
+    def get_image(self):
+        idx = np.random.randint(0, len(self.all_files))
+        img = Image.open(os.path.join(self.path, self.all_files[idx]))
+        return img
+
+class Context:
+    def __init__(self, path, device):
+        self.path = path
+        self.config_path = os.path.join(path, "config.yaml")
+        self.weight_path = os.path.join(path, "merged.pt")
+
+        self.config = OmegaConf.load(self.config_path)
+
+        self.config.dataset.num_frames = int(self.config.dataset.fps * self.config.dataset.duration)
+
+        self.im_load = ImageLoader("echo_images")
+
+        unets = []
+        for i, (k, v) in enumerate(self.config.unets.items()):
+            unets.append(Unet3D(**v, lowres_cond=(i>0))) # type: ignore
+
+        imagen_klass = ElucidatedImagen if self.config.imagen.elucidated == True else Imagen
+        del self.config.imagen.elucidated
+        imagen = imagen_klass(
+            unets = unets,
+            **OmegaConf.to_container(self.config.imagen), # type: ignore
+        )
+
+        self.trainer = ImagenTrainer(
+            imagen = imagen,
+            **self.config.trainer
+        ).to(device)
+
+        print("Loading weights from", self.weight_path)
+        additional_data = self.trainer.load(self.weight_path)
+        print("Loaded weights from", self.weight_path)
+
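`num_frames` is derived rather than read directly: it is `fps * duration` from `config.yaml`. That config file is not part of this diff, so the exact values are an assumption, but the 64-frame target shape and the 32 fps `VideoWriter` later in the file suggest:

```python
# Assumed values: config.yaml is not included in this diff. fps=32 and
# duration=2.0 are consistent with the 64-frame target shape and the
# 32 fps writer in generate_video below.
fps, duration = 32, 2.0
num_frames = int(fps * duration)
print(num_frames)  # 64
```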
+    def reshape_image(self, image):
+        try:
+            image = self.im_load.transform(image).multiply(255).byte().permute(1,2,0).numpy()
+            return image
+        except:
+            return None
+
+    def load_random_image(self):
+        print("Loading random image")
+        image = self.im_load.get_image()
+        return image
+
+    def generate_video(self, image, lvef, cond_scale):
+        print("Generating video")
+        print(f"lvef: {lvef}, cond_scale: {cond_scale}")
+
+        image = self.im_load.transform(image).unsqueeze(0)
+
+        sample_kwargs = {}
+        sample_kwargs = {
+            "text_embeds": torch.tensor([[[lvef/100.0]]]),
+            "cond_scale": cond_scale,
+            "cond_images": image,
+        }
+
+        self.trainer.eval()
+        with torch.no_grad():
+            video = self.trainer.sample(
+                batch_size=1,
+                video_frames=self.config.dataset.num_frames,
+                **sample_kwargs,
+                use_tqdm = True,
+            ).detach().cpu() # C x F x H x W
+        if video.shape[-3:] != (64, 112, 112):
+            video = torch.nn.functional.interpolate(video, size=(64, 112, 112), mode='trilinear', align_corners=False)
+        video = video.repeat((1,1,5,1,1)) # make the video loop 5 times - easier to see
+        uid = np.random.randint(0, 10) # prevent overwriting if multiple users are using the app
+        path = f"tmp/{uid}.mp4"
+        video = video.multiply(255).byte().squeeze(0).permute(1, 2, 3, 0).numpy()
+        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), 32, (112, 112))
+        for i in video:
+            out.write(i)
+        out.release()
+        return path
+
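Note how `generate_video` conditions the model: the LVEF slider value is rescaled to [0, 1] and smuggled in through the text-embedding pathway as a single "token" with a one-dimensional embedding, while the anatomy comes from `cond_images` and `cond_scale` sets the classifier-free guidance strength. A quick shape check (not part of the commit):

```python
# The scalar LVEF (e.g. 60%) enters imagen's text_embeds argument as a
# (batch=1, seq_len=1, dim=1) tensor.
import torch

lvef = 60
text_embeds = torch.tensor([[[lvef / 100.0]]])
print(text_embeds.shape, text_embeds.item())  # torch.Size([1, 1, 1]) ~0.6
```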
+context = Context(exp_path, device)
+
+with gr.Blocks(css="style.css") as demo:
+
+    with gr.Row():
+        gr.Label("Cardiac Ultrasound Video Generation Demo (paper: 905)")
+
+    with gr.Row():
+        with gr.Column():
+            with gr.Row():
+                with gr.Column(scale=3, variant="panel"):
+                    text = gr.Markdown(value="This is a live demo of our work on cardiac ultrasound video generation. The model is trained on 4-chamber cardiac ultrasound videos and can generate realistic 4-chamber videos given a target Left Ventricle Ejection Fraction. Please, start by sampling a random frame from the pool of 100 images taken from the EchoNet-Dynamic dataset, which will act as the conditional image, representing the anatomy of the video. Then, set the target LVEF, and click the button to generate a video. The process takes 30s to 60s. The model running here corresponds to the 1SCM from the paper. **Click on the video to play it.** ")
+                with gr.Column(scale=1, min_width="226"):
+                    image = gr.Image(interactive=True)
+                with gr.Column(scale=1, min_width="226"):
+                    video = gr.Video(interactive=False)
+
+            slider_ef = gr.Slider(minimum=10, maximum=90, step=1, label="Target LVEF", value=60, interactive=True)
+            slider_cond = gr.Slider(minimum=0, maximum=20, step=1, label="Conditional scale (if set to more than 1, generation time is 60s)", value=1, interactive=True)
+
+    with gr.Row():
+        img_btn = gr.Button(value="❶ Get a random cardiac ultrasound image (4Ch)")
+        run_btn = gr.Button(value="❷ Generate a video (~30s) 🚀")
+
+    image.change(context.reshape_image, inputs=[image], outputs=[image])
+    img_btn.click(context.load_random_image, inputs=[], outputs=[image])
+    run_btn.click(context.generate_video, inputs=[image, slider_ef, slider_cond], outputs=[video])
+
+if __name__ == "__main__":
+    demo.queue()
+    demo.launch()
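One detail worth flagging in the video-writing loop of `generate_video`: `cv2.VideoWriter` expects uint8 frames of shape H×W×C (three channels) that match the declared frame size, which is why the tensor is scaled by 255, cast to byte, and permuted to F×H×W×C before writing. A standalone sketch of the same write path (not part of the commit):

```python
# Write a 64-frame, 112x112 clip at 32 fps, matching the VideoWriter call
# in generate_video. Random frames stand in for model output here.
import numpy as np
import cv2

frames = (np.random.rand(64, 112, 112, 3) * 255).astype(np.uint8)  # F x H x W x C
out = cv2.VideoWriter("demo.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 32, (112, 112))
for frame in frames:
    out.write(frame)
out.release()
```

Also note that the app writes to `tmp/{uid}.mp4`, which presumes a `tmp/` directory already exists in the Space's working directory.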
echo_images/*.png (46 files)
ADDED
(binary image previews omitted; individual filenames appear in the file list above)