fbnnb committed (verified)
Commit fe89f07 · 1 Parent(s): d722551

Update scripts/evaluation/funcs.py

Files changed (1)
  1. scripts/evaluation/funcs.py +241 -240
scripts/evaluation/funcs.py CHANGED
@@ -1,240 +1,241 @@
import os, sys, glob
import numpy as np
from collections import OrderedDict
from decord import VideoReader, cpu
import cv2

import torch
import torchvision
sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
from lvdm.models.samplers.ddim import DDIMSampler
from einops import rearrange


def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,\
                        cfg_scale=1.0, hs=None, temporal_cfg_scale=None, **kwargs):
    ddim_sampler = DDIMSampler(model)
    uncond_type = model.uncond_type
    batch_size = noise_shape[0]
    fs = cond["fs"]
    del cond["fs"]
    if noise_shape[-1] == 32:
        timestep_spacing = "uniform"
        guidance_rescale = 0.0
    else:
        timestep_spacing = "uniform_trailing"
        guidance_rescale = 0.7
    ## construct unconditional guidance
    if cfg_scale != 1.0:
        if uncond_type == "empty_seq":
            prompts = batch_size * [""]
            #prompts = N * T * [""]  ## if is_imgbatch=True
            uc_emb = model.get_learned_conditioning(prompts)
        elif uncond_type == "zero_embed":
            c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
            uc_emb = torch.zeros_like(c_emb)

        ## process image embedding token
        if hasattr(model, 'embedder'):
            uc_img = torch.zeros(noise_shape[0],3,224,224).to(model.device)
            ## img: b c h w >> b l c
            uc_img = model.embedder(uc_img)
            uc_img = model.image_proj_model(uc_img)
            uc_emb = torch.cat([uc_emb, uc_img], dim=1)

        if isinstance(cond, dict):
            uc = {key:cond[key] for key in cond.keys()}
            uc.update({'c_crossattn': [uc_emb]})
        else:
            uc = uc_emb
    else:
        uc = None


    additional_decode_kwargs = {'ref_context': hs}
    x_T = None
    batch_variants = []

    for _ in range(n_samples):
        if ddim_sampler is not None:
            kwargs.update({"clean_cond": True})
            samples, _ = ddim_sampler.sample(S=ddim_steps,
                                             conditioning=cond,
                                             batch_size=noise_shape[0],
                                             shape=noise_shape[1:],
                                             verbose=False,
                                             unconditional_guidance_scale=cfg_scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             temporal_length=noise_shape[2],
                                             conditional_guidance_scale_temporal=temporal_cfg_scale,
                                             x_T=x_T,
                                             fs=fs,
                                             timestep_spacing=timestep_spacing,
                                             guidance_rescale=guidance_rescale,
                                             **kwargs
                                             )
        ## reconstruct from latent to pixel space
        batch_images = model.decode_first_stage(samples, **additional_decode_kwargs)

        index = list(range(samples.shape[2]))
        del index[1]
        del index[-2]
        samples = samples[:,:,index,:,:]
        ## reconstruct from latent to pixel space
        batch_images_middle = model.decode_first_stage(samples, **additional_decode_kwargs)
        batch_images[:,:,batch_images.shape[2]//2-1:batch_images.shape[2]//2+1] = batch_images_middle[:,:,batch_images.shape[2]//2-2:batch_images.shape[2]//2]



        batch_variants.append(batch_images)
    ## batch, <samples>, c, t, h, w
    batch_variants = torch.stack(batch_variants, dim=1)
    return batch_variants


def get_filelist(data_dir, ext='*'):
    file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext))
    file_list.sort()
    return file_list

def get_dirlist(path):
    list = []
    if (os.path.exists(path)):
        files = os.listdir(path)
        for file in files:
            m = os.path.join(path,file)
            if (os.path.isdir(m)):
                list.append(m)
    list.sort()
    return list


def load_model_checkpoint(model, ckpt):
    def load_checkpoint(model, ckpt, full_strict):
        state_dict = torch.load(ckpt, map_location="cpu")
        if "state_dict" in list(state_dict.keys()):
            state_dict = state_dict["state_dict"]
            try:
                model.load_state_dict(state_dict, strict=full_strict)
            except:
                ## rename the keys for 256x256 model
                new_pl_sd = OrderedDict()
                for k,v in state_dict.items():
                    new_pl_sd[k] = v

                for k in list(new_pl_sd.keys()):
                    if "framestride_embed" in k:
                        new_key = k.replace("framestride_embed", "fps_embedding")
                        new_pl_sd[new_key] = new_pl_sd[k]
                        del new_pl_sd[k]
                model.load_state_dict(new_pl_sd, strict=full_strict)
        else:
            ## deepspeed
            new_pl_sd = OrderedDict()
            for key in state_dict['module'].keys():
                new_pl_sd[key[16:]]=state_dict['module'][key]
            model.load_state_dict(new_pl_sd, strict=full_strict)

        return model
    load_checkpoint(model, ckpt, full_strict=True)
    print('>>> model checkpoint loaded.')
    return model


def load_prompts(prompt_file):
    f = open(prompt_file, 'r')
    prompt_list = []
    for idx, line in enumerate(f.readlines()):
        l = line.strip()
        if len(l) != 0:
            prompt_list.append(l)
        f.close()
    return prompt_list


def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16):
    '''
    Notice about some special cases:
    1. video_frames=-1 means to take all the frames (with fs=1)
    2. when the total video frames is less than required, padding strategy will be used (repeated last frame)
    '''
    fps_list = []
    batch_tensor = []
    assert frame_stride > 0, "valid frame stride should be a positive interge!"
    for filepath in filepath_list:
        padding_num = 0
        vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
        fps = vidreader.get_avg_fps()
        total_frames = len(vidreader)
        max_valid_frames = (total_frames-1) // frame_stride + 1
        if video_frames < 0:
            ## all frames are collected: fs=1 is a must
            required_frames = total_frames
            frame_stride = 1
        else:
            required_frames = video_frames
        query_frames = min(required_frames, max_valid_frames)
        frame_indices = [frame_stride*i for i in range(query_frames)]

        ## [t,h,w,c] -> [c,t,h,w]
        frames = vidreader.get_batch(frame_indices)
        frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
        frame_tensor = (frame_tensor / 255. - 0.5) * 2
        if max_valid_frames < required_frames:
            padding_num = required_frames - max_valid_frames
            frame_tensor = torch.cat([frame_tensor, *([frame_tensor[:,-1:,:,:]]*padding_num)], dim=1)
            print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.')
        batch_tensor.append(frame_tensor)
        sample_fps = int(fps/frame_stride)
        fps_list.append(sample_fps)

    return torch.stack(batch_tensor, dim=0)

from PIL import Image
def load_image_batch(filepath_list, image_size=(256,256)):
    batch_tensor = []
    for filepath in filepath_list:
        _, filename = os.path.split(filepath)
        _, ext = os.path.splitext(filename)
        if ext == '.mp4':
            vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0])
            frame = vidreader.get_batch([0])
            img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float()
        elif ext == '.png' or ext == '.jpg':
            img = Image.open(filepath).convert("RGB")
            rgb_img = np.array(img, np.float32)
            #bgr_img = cv2.imread(filepath, cv2.IMREAD_COLOR)
            #bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
            rgb_img = cv2.resize(rgb_img, (image_size[1],image_size[0]), interpolation=cv2.INTER_LINEAR)
            img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float()
        else:
            print(f'ERROR: <{ext}> image loading only support format: [mp4], [png], [jpg]')
            raise NotImplementedError
        img_tensor = (img_tensor / 255. - 0.5) * 2
        batch_tensor.append(img_tensor)
    return torch.stack(batch_tensor, dim=0)


def save_videos(batch_tensors, savedir, filenames, fps=10):
    # b,samples,c,t,h,w
    n_samples = batch_tensors.shape[1]
    for idx, vid_tensor in enumerate(batch_tensors):
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), -1., 1.)
        video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
        frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w]
        grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w]
        grid = (grid + 1.0) / 2.0
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
        savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
+       print("saving path:", savepath)
        torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})


def get_latent_z(model, videos):
    b, c, t, h, w = videos.shape
    x = rearrange(videos, 'b c t h w -> (b t) c h w')
    z = model.encode_first_stage(x)
    z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
    return z
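For context, a minimal sketch of how these helpers are typically chained by the evaluation scripts (run from scripts/evaluation/ so that `funcs` and `lvdm` are importable). The `build_model_from_config` constructor, the file paths, the `fs` value, and the latent `noise_shape` below are illustrative assumptions, not part of this file or this commit.

import os
import torch
from funcs import load_model_checkpoint, load_prompts, batch_ddim_sampling, save_videos

model = build_model_from_config("config.yaml")      # hypothetical model constructor
model = load_model_checkpoint(model, "model.ckpt")  # handles plain and deepspeed checkpoints
model.eval()

prompts = load_prompts("prompts.txt")               # one prompt per non-empty line
cond = {
    "c_crossattn": [model.get_learned_conditioning(prompts)],
    "fs": torch.tensor([3] * len(prompts), dtype=torch.long, device=model.device),
}
noise_shape = [len(prompts), 4, 16, 40, 64]         # b, c, t, h, w in latent space (illustrative)

with torch.no_grad():
    # returns a tensor shaped [batch, n_samples, c, t, h, w]
    videos = batch_ddim_sampling(model, cond, noise_shape,
                                 n_samples=1, ddim_steps=50, ddim_eta=1.0, cfg_scale=7.5)

os.makedirs("results", exist_ok=True)
save_videos(videos, "results", [f"sample_{i}" for i in range(len(prompts))], fps=8)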