fbnnb committed
Commit da151b2 · verified · 1 Parent(s): a8477f8

Update scripts/evaluation/funcs.py

Files changed (1)
  1. scripts/evaluation/funcs.py +250 -248
scripts/evaluation/funcs.py CHANGED
@@ -1,248 +1,250 @@
import os, sys, glob
import numpy as np
from collections import OrderedDict
from decord import VideoReader, cpu
import cv2

import torch
import torchvision
sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
from lvdm.models.samplers.ddim import DDIMSampler
from einops import rearrange


@torch.no_grad()
def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,\
                        cfg_scale=1.0, hs=None, temporal_cfg_scale=None, **kwargs):
    ddim_sampler = DDIMSampler(model)
    uncond_type = model.uncond_type
    batch_size = noise_shape[0]
    fs = cond["fs"]
    del cond["fs"]
    if noise_shape[-1] == 32:
        timestep_spacing = "uniform"
        guidance_rescale = 0.0
    else:
        timestep_spacing = "uniform_trailing"
        guidance_rescale = 0.7
    ## construct unconditional guidance
    if cfg_scale != 1.0:
        if uncond_type == "empty_seq":
            prompts = batch_size * [""]
            #prompts = N * T * [""]  ## if is_imgbatch=True
            uc_emb = model.get_learned_conditioning(prompts)
        elif uncond_type == "zero_embed":
            c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
            uc_emb = torch.zeros_like(c_emb)

        ## process image embedding token
        if hasattr(model, 'embedder'):
            uc_img = torch.zeros(noise_shape[0],3,224,224).to(model.device)
            ## img: b c h w >> b l c
            uc_img = model.embedder(uc_img)
            uc_img = model.image_proj_model(uc_img)
            uc_emb = torch.cat([uc_emb, uc_img], dim=1)

        if isinstance(cond, dict):
            uc = {key:cond[key] for key in cond.keys()}
            uc.update({'c_crossattn': [uc_emb]})
        else:
            uc = uc_emb
    else:
        uc = None


    additional_decode_kwargs = {'ref_context': hs}
    x_T = None
    batch_variants = []

    for _ in range(n_samples):
        if ddim_sampler is not None:
            kwargs.update({"clean_cond": True})
            samples, _ = ddim_sampler.sample(S=ddim_steps,
                                             conditioning=cond,
                                             batch_size=noise_shape[0],
                                             shape=noise_shape[1:],
                                             verbose=False,
                                             unconditional_guidance_scale=cfg_scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             temporal_length=noise_shape[2],
                                             conditional_guidance_scale_temporal=temporal_cfg_scale,
                                             x_T=x_T,
                                             fs=fs,
                                             timestep_spacing=timestep_spacing,
                                             guidance_rescale=guidance_rescale,
                                             **kwargs
                                             )

        torch.cuda.empty_cache()
        model.cpu()
        model.first_stage_model.cuda()
+        for k, v in model.named_parameters():
+            print(k, v.device)

        ## reconstruct from latent to pixel space
        batch_images = model.decode_first_stage(samples, **additional_decode_kwargs)

        index = list(range(samples.shape[2]))
        del index[1]
        del index[-2]
        samples = samples[:,:,index,:,:]
        ## reconstruct from latent to pixel space
        batch_images_middle = model.decode_first_stage(samples, **additional_decode_kwargs)
        batch_images[:,:,batch_images.shape[2]//2-1:batch_images.shape[2]//2+1] = batch_images_middle[:,:,batch_images.shape[2]//2-2:batch_images.shape[2]//2]



        batch_variants.append(batch_images)
    ## batch, <samples>, c, t, h, w
    batch_variants = torch.stack(batch_variants, dim=1)
    return batch_variants


def get_filelist(data_dir, ext='*'):
    file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext))
    file_list.sort()
    return file_list

def get_dirlist(path):
    list = []
    if (os.path.exists(path)):
        files = os.listdir(path)
        for file in files:
            m = os.path.join(path, file)
            if (os.path.isdir(m)):
                list.append(m)
    list.sort()
    return list


def load_model_checkpoint(model, ckpt):
    def load_checkpoint(model, ckpt, full_strict):
        state_dict = torch.load(ckpt, map_location="cpu")

        if "state_dict" in list(state_dict.keys()):
            state_dict = state_dict["state_dict"]
            try:
                model.load_state_dict(state_dict, strict=full_strict)
            except:
                ## rename the keys for 256x256 model
                new_pl_sd = OrderedDict()
                for k,v in state_dict.items():
                    new_pl_sd[k] = v

                for k in list(new_pl_sd.keys()):
                    if "framestride_embed" in k:
                        new_key = k.replace("framestride_embed", "fps_embedding")
                        new_pl_sd[new_key] = new_pl_sd[k]
                        del new_pl_sd[k]
                model.load_state_dict(new_pl_sd, strict=full_strict)
        elif 'module' in state_dict:
            ## deepspeed
            new_pl_sd = OrderedDict()
            for key in state_dict['module'].keys():
                new_pl_sd[key[16:]]=state_dict['module'][key]
            model.load_state_dict(new_pl_sd, strict=full_strict)
        else:
            model.load_state_dict(state_dict, strict=True)
        return model
    load_checkpoint(model, ckpt, full_strict=True)
    print('>>> model checkpoint loaded.')
    return model


def load_prompts(prompt_file):
    f = open(prompt_file, 'r')
    prompt_list = []
    for idx, line in enumerate(f.readlines()):
        l = line.strip()
        if len(l) != 0:
            prompt_list.append(l)
        f.close()
    return prompt_list


def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16):
    '''
    Notice about some special cases:
    1. video_frames=-1 means to take all the frames (with fs=1)
    2. when the total video frames is less than required, padding strategy will be used (repeated last frame)
    '''
    fps_list = []
    batch_tensor = []
    assert frame_stride > 0, "valid frame stride should be a positive interge!"
    for filepath in filepath_list:
        padding_num = 0
        vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
        fps = vidreader.get_avg_fps()
        total_frames = len(vidreader)
        max_valid_frames = (total_frames-1) // frame_stride + 1
        if video_frames < 0:
            ## all frames are collected: fs=1 is a must
            required_frames = total_frames
            frame_stride = 1
        else:
            required_frames = video_frames
        query_frames = min(required_frames, max_valid_frames)
        frame_indices = [frame_stride*i for i in range(query_frames)]

        ## [t,h,w,c] -> [c,t,h,w]
        frames = vidreader.get_batch(frame_indices)
        frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
        frame_tensor = (frame_tensor / 255. - 0.5) * 2
        if max_valid_frames < required_frames:
            padding_num = required_frames - max_valid_frames
            frame_tensor = torch.cat([frame_tensor, *([frame_tensor[:,-1:,:,:]]*padding_num)], dim=1)
            print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.')
        batch_tensor.append(frame_tensor)
        sample_fps = int(fps/frame_stride)
        fps_list.append(sample_fps)

    return torch.stack(batch_tensor, dim=0)

from PIL import Image
def load_image_batch(filepath_list, image_size=(256,256)):
    batch_tensor = []
    for filepath in filepath_list:
        _, filename = os.path.split(filepath)
        _, ext = os.path.splitext(filename)
        if ext == '.mp4':
            vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0])
            frame = vidreader.get_batch([0])
            img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float()
        elif ext == '.png' or ext == '.jpg':
            img = Image.open(filepath).convert("RGB")
            rgb_img = np.array(img, np.float32)
            #bgr_img = cv2.imread(filepath, cv2.IMREAD_COLOR)
            #bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
            rgb_img = cv2.resize(rgb_img, (image_size[1],image_size[0]), interpolation=cv2.INTER_LINEAR)
            img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float()
        else:
            print(f'ERROR: <{ext}> image loading only support format: [mp4], [png], [jpg]')
            raise NotImplementedError
        img_tensor = (img_tensor / 255. - 0.5) * 2
        batch_tensor.append(img_tensor)
    return torch.stack(batch_tensor, dim=0)


def save_videos(batch_tensors, savedir, filenames, fps=10):
    # b,samples,c,t,h,w
    n_samples = batch_tensors.shape[1]
    for idx, vid_tensor in enumerate(batch_tensors):
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), -1., 1.)
        video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
        frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w]
        grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w]
        grid = (grid + 1.0) / 2.0
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
        savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
        torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})


def get_latent_z(model, videos):
    b, c, t, h, w = videos.shape
    x = rearrange(videos, 'b c t h w -> (b t) c h w')
    z = model.encode_first_stage(x)
    z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
    return z

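
For context, the following is a minimal, illustrative sketch of how the helpers in this file are typically driven by an inference script. It is not part of this commit: the config path, checkpoint path, prompt file, noise shape, and the instantiate_from_config import are assumptions about the surrounding repository layout, and image-conditioned checkpoints would need extra conditioning inputs that this sketch omits.

# Illustrative usage sketch (not part of this commit); paths and the
# instantiate_from_config import are assumed, not confirmed by this repo.
import torch
from omegaconf import OmegaConf
from utils.utils import instantiate_from_config  # assumed repo helper

config = OmegaConf.load("configs/inference_512_v1.0.yaml")      # hypothetical config path
model = instantiate_from_config(config.model).cuda().eval()
model = load_model_checkpoint(model, "checkpoints/model.ckpt")  # hypothetical checkpoint path

prompts = load_prompts("prompts/test_prompts.txt")              # hypothetical prompt file
channels = model.model.diffusion_model.out_channels
noise_shape = [1, channels, 16, 40, 64]                         # b, c, t, h, w (example latent size)

# batch_ddim_sampling expects a frame-stride tensor under the "fs" key and pops it itself;
# a real image-to-video run would also attach image embeddings / concat latents to cond.
cond = {
    "c_crossattn": [model.get_learned_conditioning(prompts[:1])],
    "fs": torch.tensor([3], dtype=torch.long, device=model.device),
}
samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1,
                              ddim_steps=50, ddim_eta=1.0, cfg_scale=7.5)
save_videos(samples, "results", filenames=["sample_0"], fps=8)  # hypothetical output dir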