HReynaud Anonymous commited on
Commit
316f1d5
·
0 Parent(s):

Duplicate from anon-SGXT/echocardiogram-video-diffusion

Browse files

Co-authored-by: Anonymous <[email protected]>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +34 -0
  2. .gitignore +3 -0
  3. README.md +14 -0
  4. app.py +145 -0
  5. echo_images/0X10094BA0A028EAC3.png +0 -0
  6. echo_images/0X1013E8A4864781B.png +0 -0
  7. echo_images/0X12B890B1E2E14CC4.png +0 -0
  8. echo_images/0X13E043A35E3EB490.png +0 -0
  9. echo_images/0X159BDA520C61736A.png +0 -0
  10. echo_images/0X15DA8D60960ABB2B.png +0 -0
  11. echo_images/0X16AF26F9A372EEDE.png +0 -0
  12. echo_images/0X17BC4EF4BF83368B.png +0 -0
  13. echo_images/0X1B379931357428C0.png +0 -0
  14. echo_images/0X1CDD9C054D8FB60D.png +0 -0
  15. echo_images/0X1DF7163A74801695.png +0 -0
  16. echo_images/0X20C397F012441121.png +0 -0
  17. echo_images/0X22A1A8A656653343.png +0 -0
  18. echo_images/0X22D7FDCF2827269E.png +0 -0
  19. echo_images/0X230F00FD0DF5D71C.png +0 -0
  20. echo_images/0X244CAB3550320216.png +0 -0
  21. echo_images/0X24FEF7D294B35A5B.png +0 -0
  22. echo_images/0X25D970C75A57B3F2.png +0 -0
  23. echo_images/0X277FC348812C0E79.png +0 -0
  24. echo_images/0X27836E538BD008A.png +0 -0
  25. echo_images/0X2840438B29E95F1F.png +0 -0
  26. echo_images/0X29A336DCE20541A0.png +0 -0
  27. echo_images/0X29C81728B50A2E6C.png +0 -0
  28. echo_images/0X2A830BC4A3A36A93.png +0 -0
  29. echo_images/0X2AD994F98C491FA6.png +0 -0
  30. echo_images/0X2BB766EF1A13DECC.png +0 -0
  31. echo_images/0X2DA99F9FC1DAD8A9.png +0 -0
  32. echo_images/0X3545F8A008B34ED0.png +0 -0
  33. echo_images/0X36E4468C9E659B89.png +0 -0
  34. echo_images/0X39CA8CC96A5D5E8B.png +0 -0
  35. echo_images/0X3B01B7487E3D81EA.png +0 -0
  36. echo_images/0X3B0D2D527C387A0E.png +0 -0
  37. echo_images/0X3B54A5459841DCE8.png +0 -0
  38. echo_images/0X3B9FBD87EE113D62.png +0 -0
  39. echo_images/0X3BA9F7C9DB0CF55B.png +0 -0
  40. echo_images/0X3DA2B290B58A6540.png +0 -0
  41. echo_images/0X3E2F182038897EA5.png +0 -0
  42. echo_images/0X3F076329C702F768.png +0 -0
  43. echo_images/0X4130EB4CD7ED958B.png +0 -0
  44. echo_images/0X42E8226CA93B7BAC.png +0 -0
  45. echo_images/0X45418C574D97027A.png +0 -0
  46. echo_images/0X45CE057EC2EB577F.png +0 -0
  47. echo_images/0X463A7B7D46C6CA4.png +0 -0
  48. echo_images/0X463C296E8E65DA97.png +0 -0
  49. echo_images/0X46682D67FA3FE237.png +0 -0
  50. echo_images/0X487B52623BC14C25.png +0 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.mp4
2
+ *.ipynb
3
+ *__pycache__*
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: EchoNet Video Diffusion
3
+ emoji: 🖤
4
+ colorFrom: gray
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.17.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: anon-SGXT/echocardiogram-video-diffusion
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from omegaconf import OmegaConf
4
+ from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer, ElucidatedImagenConfig, NullUnet, Imagen
5
+ import torch
6
+ import numpy as np
7
+ import cv2
8
+ from PIL import Image
9
+ import torchvision.transforms as T
10
+
11
+
12
# Select the compute device: prefer GPU, fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Directory holding config.yaml and the merged model checkpoint.
exp_path = "model"
14
+
15
class BetterCenterCrop(T.CenterCrop):
    """Center-crop an image tensor to the largest possible centered square.

    Unlike the parent ``T.CenterCrop``, the configured size is ignored: the
    crop side is recomputed per image as ``min(height, width)``, so the output
    is always the biggest square the input can provide.
    """

    def __call__(self, img):
        # Spatial dims are the last two axes of the (C, H, W) tensor.
        height, width = img.shape[-2], img.shape[-1]
        side = min(height, width)
        return T.functional.center_crop(img, side)
22
+
23
class ImageLoader:
    """Serves random conditioning frames from a directory of image files.

    Also exposes ``transform``, the preprocessing pipeline that turns a PIL
    image into a 112x112 tensor (to-tensor -> largest centered square ->
    resize).
    """

    def __init__(self, path) -> None:
        self.path = path
        # Snapshot of the directory contents taken once at construction.
        self.all_files = os.listdir(path)
        self.transform = T.Compose([
            T.ToTensor(),
            BetterCenterCrop((112, 112)),
            T.Resize((112, 112)),
        ])

    def get_image(self):
        """Open and return one image from the pool, picked uniformly at random."""
        chosen = self.all_files[np.random.randint(0, len(self.all_files))]
        return Image.open(os.path.join(self.path, chosen))
37
+
38
class Context:
    """Owns the diffusion model and implements the demo's three callbacks.

    Loads the OmegaConf config and merged weights from *path*, builds the
    cascaded ``Unet3D`` stack plus the (Elucidated)Imagen trainer on *device*,
    and exposes ``reshape_image``, ``load_random_image`` and ``generate_video``
    for the Gradio UI.
    """

    def __init__(self, path, device):
        self.path = path
        self.config_path = os.path.join(path, "config.yaml")
        self.weight_path = os.path.join(path, "merged.pt")

        self.config = OmegaConf.load(self.config_path)

        # Total frame count is derived from the configured fps and duration.
        self.config.dataset.num_frames = int(self.config.dataset.fps * self.config.dataset.duration)

        self.im_load = ImageLoader("echo_images")

        # One Unet3D per config entry; every unet after the first is
        # conditioned on the lower-resolution stage (cascaded diffusion).
        unets = []
        for i, (k, v) in enumerate(self.config.unets.items()):
            unets.append(Unet3D(**v, lowres_cond=(i > 0)))  # type: ignore

        # Pick the imagen flavour, then drop the selector key so the remaining
        # config maps 1:1 onto the constructor's keyword arguments.
        imagen_klass = ElucidatedImagen if self.config.imagen.elucidated else Imagen
        del self.config.imagen.elucidated
        imagen = imagen_klass(
            unets=unets,
            **OmegaConf.to_container(self.config.imagen),  # type: ignore
        )

        self.trainer = ImagenTrainer(
            imagen=imagen,
            **self.config.trainer,
        ).to(device)

        print("Loading weights from", self.weight_path)
        additional_data = self.trainer.load(self.weight_path)
        print("Loaded weights from", self.weight_path)

    def reshape_image(self, image):
        """Normalize *image* for display as a uint8 HxWxC array.

        Returns None when the input cannot be transformed (e.g. the widget was
        cleared), which leaves the Gradio image component empty.
        """
        try:
            return self.im_load.transform(image).multiply(255).byte().permute(1, 2, 0).numpy()
        except Exception:
            # Best-effort by design: invalid input clears the widget instead
            # of crashing the callback.
            return None

    def load_random_image(self):
        """Return a random conditioning frame from the image pool."""
        print("Loading random image")
        return self.im_load.get_image()

    def generate_video(self, image, lvef, cond_scale):
        """Sample an echo video conditioned on *image* and the target LVEF.

        Args:
            image: conditioning frame (uint8 array as produced by reshape_image).
            lvef: target Left Ventricle Ejection Fraction, in percent.
            cond_scale: classifier-free guidance scale.

        Returns:
            Filesystem path of the written mp4.
        """
        print("Generating video")
        print(f"lvef: {lvef}, cond_scale: {cond_scale}")

        image = self.im_load.transform(image).unsqueeze(0)

        sample_kwargs = {
            # The LVEF is injected as a single scalar "text" embedding in [0, 1].
            "text_embeds": torch.tensor([[[lvef / 100.0]]]),
            "cond_scale": cond_scale,
            "cond_images": image,
        }

        self.trainer.eval()
        with torch.no_grad():
            video = self.trainer.sample(
                batch_size=1,
                video_frames=self.config.dataset.num_frames,
                **sample_kwargs,
                use_tqdm=True,
            ).detach().cpu()  # (1, C, F, H, W)
        # Normalize to the expected 64-frame, 112x112 resolution if needed.
        if video.shape[-3:] != (64, 112, 112):
            video = torch.nn.functional.interpolate(video, size=(64, 112, 112), mode='trilinear', align_corners=False)
        video = video.repeat((1, 1, 5, 1, 1))  # make the video loop 5 times - easier to see

        # The output dir may not exist on a fresh clone (*.mp4 is gitignored).
        os.makedirs("tmp", exist_ok=True)
        # Wide id range so concurrent users are unlikely to overwrite each
        # other's output (the previous 0-9 range made collisions likely).
        uid = np.random.randint(0, 1_000_000)
        path = f"tmp/{uid}.mp4"
        video = video.multiply(255).byte().squeeze(0).permute(1, 2, 3, 0).numpy()  # (F, H, W, C) uint8
        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), 32, (112, 112))
        for frame in video:
            # NOTE(review): cv2 expects BGR frames; the model output is
            # presumably RGB. Harmless for near-grayscale echo — confirm.
            out.write(frame)
        out.release()
        return path
114
+
115
# Build the model context once at import time; all requests share it.
context = Context(exp_path, device)

with gr.Blocks(css="style.css") as demo:

    with gr.Row():
        gr.Label("Cardiac Ultrasound Video Generation Demo (paper: 905)")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                # Left: instructions; middle: conditioning image; right: result.
                with gr.Column(scale=3, variant="panel"):
                    text = gr.Markdown(value="This is a live demo of our work on cardiac ultrasound video generation. The model is trained on 4-chamber cardiac ultrasound videos and can generate realistic 4-chamber videos given a target Left Ventricle Ejection Fraction. Please, start by sampling a random frame from the pool of 100 images taken from the EchoNet-Dynamic dataset, which will act as the conditional image, representing the anatomy of the video. Then, set the target LVEF, and click the button to generate a video. The process takes 30s to 60s. The model running here corresponds to the 1SCM from the paper. **Click on the video to play it.** ")
                with gr.Column(scale=1, min_width="226"):
                    image = gr.Image(interactive=True)
                with gr.Column(scale=1, min_width="226"):
                    video = gr.Video(interactive=False)

            # Generation controls: target ejection fraction and guidance scale.
            slider_ef = gr.Slider(minimum=10, maximum=90, step=1, label="Target LVEF", value=60, interactive=True)
            slider_cond = gr.Slider(minimum=0, maximum=20, step=1, label="Conditional scale (if set to more than 1, generation time is 60s)", value=1, interactive=True)

    with gr.Row():
        img_btn = gr.Button(value="❶ Get a random cardiac ultrasound image (4Ch)")
        run_btn = gr.Button(value="❷ Generate a video (~30s) 🚀")

    # Wiring: uploaded/changed images are normalized in place; the two buttons
    # sample a conditioning frame and run generation respectively.
    image.change(context.reshape_image, inputs=[image], outputs=[image])
    img_btn.click(context.load_random_image, inputs=[], outputs=[image])
    run_btn.click(context.generate_video, inputs=[image, slider_ef, slider_cond], outputs=[video])

if __name__ == "__main__":
    demo.queue()
    demo.launch()
echo_images/0X10094BA0A028EAC3.png ADDED
echo_images/0X1013E8A4864781B.png ADDED
echo_images/0X12B890B1E2E14CC4.png ADDED
echo_images/0X13E043A35E3EB490.png ADDED
echo_images/0X159BDA520C61736A.png ADDED
echo_images/0X15DA8D60960ABB2B.png ADDED
echo_images/0X16AF26F9A372EEDE.png ADDED
echo_images/0X17BC4EF4BF83368B.png ADDED
echo_images/0X1B379931357428C0.png ADDED
echo_images/0X1CDD9C054D8FB60D.png ADDED
echo_images/0X1DF7163A74801695.png ADDED
echo_images/0X20C397F012441121.png ADDED
echo_images/0X22A1A8A656653343.png ADDED
echo_images/0X22D7FDCF2827269E.png ADDED
echo_images/0X230F00FD0DF5D71C.png ADDED
echo_images/0X244CAB3550320216.png ADDED
echo_images/0X24FEF7D294B35A5B.png ADDED
echo_images/0X25D970C75A57B3F2.png ADDED
echo_images/0X277FC348812C0E79.png ADDED
echo_images/0X27836E538BD008A.png ADDED
echo_images/0X2840438B29E95F1F.png ADDED
echo_images/0X29A336DCE20541A0.png ADDED
echo_images/0X29C81728B50A2E6C.png ADDED
echo_images/0X2A830BC4A3A36A93.png ADDED
echo_images/0X2AD994F98C491FA6.png ADDED
echo_images/0X2BB766EF1A13DECC.png ADDED
echo_images/0X2DA99F9FC1DAD8A9.png ADDED
echo_images/0X3545F8A008B34ED0.png ADDED
echo_images/0X36E4468C9E659B89.png ADDED
echo_images/0X39CA8CC96A5D5E8B.png ADDED
echo_images/0X3B01B7487E3D81EA.png ADDED
echo_images/0X3B0D2D527C387A0E.png ADDED
echo_images/0X3B54A5459841DCE8.png ADDED
echo_images/0X3B9FBD87EE113D62.png ADDED
echo_images/0X3BA9F7C9DB0CF55B.png ADDED
echo_images/0X3DA2B290B58A6540.png ADDED
echo_images/0X3E2F182038897EA5.png ADDED
echo_images/0X3F076329C702F768.png ADDED
echo_images/0X4130EB4CD7ED958B.png ADDED
echo_images/0X42E8226CA93B7BAC.png ADDED
echo_images/0X45418C574D97027A.png ADDED
echo_images/0X45CE057EC2EB577F.png ADDED
echo_images/0X463A7B7D46C6CA4.png ADDED
echo_images/0X463C296E8E65DA97.png ADDED
echo_images/0X46682D67FA3FE237.png ADDED
echo_images/0X487B52623BC14C25.png ADDED