rodrigomasini committed
Commit 6f2ec28 · verified · 1 Parent(s): 3805ac7

Create calculations.py

Files changed (1): calculations.py (+381, -0)

calculations.py ADDED
@@ -0,0 +1,381 @@
+ # List of requirements
+ # torch~=1.13
+ # torchvision
+ # opencv-python
+ # scipy
+ # numpy
+ # tqdm
+ # timm
+ # einops
+ # scikit-video
+ # pillow
+ # logger
+ # diffusers
+ # transformers
+ # accelerate
+ # requests
+ # pycocoevalcap
+
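+ # A minimal install sketch mirroring the list above (exact versions may differ;
+ # the standard-library logging module needs no separate package):
+ #   pip install "torch~=1.13" torchvision opencv-python scipy numpy tqdm timm \
+ #       einops scikit-video pillow diffusers transformers accelerate requests pycocoevalcap
+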
+ import os
+ import torch
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from transformers import CLIPProcessor, CLIPModel, AutoTokenizer
+ import time
+ import logging
+ from tqdm import tqdm
+ import argparse
+ import torchvision.transforms as transforms
+ from torchvision.transforms import Resize
+ from torchvision.utils import save_image
+ from diffusers import StableDiffusionXLPipeline
+ import requests
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
+ import ipdb
+ from pycocoevalcap.cider.cider import Cider
+ from pycocoevalcap.bleu.bleu import Bleu
+
+ def calculate_clip_score(video_path, text, model, tokenizer):
+     # Load the video
+     cap = cv2.VideoCapture(video_path)
+
+     # Extract frames from the video
+     frames = []
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         resized_frame = cv2.resize(frame, (224, 224))  # Resize the frame to match the expected input size
+         frames.append(resized_frame)
+
+     # Convert numpy arrays to tensors and change dtype to float
+     tensor_frames = [torch.from_numpy(frame).permute(2, 0, 1).float() for frame in frames]
+
+     # Initialize an empty tensor to store the concatenated features
+     concatenated_features = torch.tensor([], device=device)
+
+     # Generate embeddings for each frame and concatenate the features
+     with torch.no_grad():
+         for frame in tensor_frames:
+             frame_input = frame.unsqueeze(0).to(device)  # Add batch dimension and move the frame to the device
+             frame_features = model.get_image_features(frame_input)
+             concatenated_features = torch.cat((concatenated_features, frame_features), dim=0)
+
+     # Tokenize the text
+     text_tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=77)
+
+     # Convert the tokenized text to a tensor and move it to the device
+     text_input = text_tokens["input_ids"].to(device)
+
+     # Generate text embeddings
+     with torch.no_grad():
+         text_features = model.get_text_features(text_input)
+
+     # Calculate the cosine similarity scores
+     concatenated_features = concatenated_features / concatenated_features.norm(p=2, dim=-1, keepdim=True)
+     text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
+     clip_score_frames = concatenated_features @ text_features.T
+     # Average the CLIP score across all frames; reflects how well the video matches the text prompt
+     clip_score_frames_avg = clip_score_frames.mean().item()
+
+     return clip_score_frames_avg
+
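+ # Minimal usage sketch (illustrative file names; the CLIP weights can also come from
+ # the Hugging Face hub instead of the local checkpoint directory used in __main__):
+ #
+ #   device = "cuda" if torch.cuda.is_available() else "cpu"
+ #   clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+ #   clip_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+ #   score = calculate_clip_score("example.mp4", "a dog running on grass", clip_model, clip_tokenizer)
+ #   print(f"CLIP score: {score:.4f}")
+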
+ def calculate_clip_temp_score(video_path, model):
+     # Load the video
+     cap = cv2.VideoCapture(video_path)
+     to_tensor = transforms.ToTensor()
+     # Extract frames from the video
+     frames = []
+     SD_images = []
+     resize = transforms.Resize([224, 224])
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         # resized_frame = cv2.resize(frame, (224, 224))  # Resize the frame to match the expected input size
+         frames.append(frame)
+
+     tensor_frames = torch.stack([resize(torch.from_numpy(frame).permute(2, 0, 1).float()) for frame in frames])
+
+     # tensor_frames = [extracted_frames[i] for i in range(extracted_frames.size()[0])]
+     concatenated_frame_features = []
+
+     # Generate embeddings for each frame and concatenate the features
+     with torch.no_grad():
+         for frame in tensor_frames:  # Embed frames one at a time; a full video may not fit in memory as a single CLIP batch
+             frame_input = frame.unsqueeze(0).to(device)  # Add batch dimension and move the frame to the device
+             frame_feature = model.get_image_features(frame_input)
+             concatenated_frame_features.append(frame_feature)
+
+     concatenated_frame_features = torch.cat(concatenated_frame_features, dim=0)
+
+     # Calculate the similarity scores between consecutive frames
+     clip_temp_score = []
+     concatenated_frame_features = concatenated_frame_features / concatenated_frame_features.norm(p=2, dim=-1, keepdim=True)
+     # ipdb.set_trace()
+
+     for i in range(concatenated_frame_features.size()[0] - 1):
+         clip_temp_score.append(concatenated_frame_features[i].unsqueeze(0) @ concatenated_frame_features[i + 1].unsqueeze(0).T)
+     clip_temp_score = torch.cat(clip_temp_score, dim=0)
+     # Average the consecutive-frame similarity; reflects temporal consistency
+     clip_temp_score_avg = clip_temp_score.mean().item()
+
+     return clip_temp_score_avg
+
+ def compute_max(scorer, gt_prompts, pred_prompts):
+     scores = []
+     for pred_prompt in pred_prompts:
+         for gt_prompt in gt_prompts:
+             cand = {0: [pred_prompt]}
+             ref = {0: [gt_prompt]}
+             score, _ = scorer.compute_score(ref, cand)
+             scores.append(score)
+     return np.max(scores)
+
+ def calculate_blip_bleu(video_path, original_text, blip2_model, blip2_processor):
+     # Load the video
+     cap = cv2.VideoCapture(video_path)
+
+     scorer_cider = Cider()
+     bleu1 = Bleu(n=1)
+     bleu2 = Bleu(n=2)
+     bleu3 = Bleu(n=3)
+     bleu4 = Bleu(n=4)
+
+     # Extract frames from the video
+     frames = []
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes frames in BGR order; convert to RGB as in the other metrics
+         resized_frame = cv2.resize(frame, (224, 224))  # Resize the frame to match the expected input size
+         frames.append(resized_frame)
+
+     # Convert numpy arrays to tensors and change dtype to float
+     tensor_frames = torch.stack([torch.from_numpy(frame).permute(2, 0, 1).float() for frame in frames])
+     # Get five captions for one video
+     Num = 5
+     captions = []
+     # for i in range(Num):
+     N = len(tensor_frames)
+     indices = torch.linspace(0, N - 1, Num).long()
+     extracted_frames = torch.index_select(tensor_frames, 0, indices)
+     for i in range(Num):
+         frame = extracted_frames[i]
+         inputs = blip2_processor(images=frame, return_tensors="pt").to(device, torch.float16)
+         generated_ids = blip2_model.generate(**inputs)
+         generated_text = blip2_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+         captions.append(generated_text)
+
+     original_text = [original_text]
+     cider_score = compute_max(scorer_cider, original_text, captions)
+     bleu1_score = compute_max(bleu1, original_text, captions)
+     bleu2_score = compute_max(bleu2, original_text, captions)
+     bleu3_score = compute_max(bleu3, original_text, captions)
+     bleu4_score = compute_max(bleu4, original_text, captions)
+
+     blip_bleu_caps_avg = (bleu1_score + bleu2_score + bleu3_score + bleu4_score) / 4
+
+     return blip_bleu_caps_avg
+
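+ # Minimal usage sketch (illustrative file names; the BLIP-2 weights can also be pulled
+ # from the Hugging Face hub instead of the local checkpoint directory used in __main__):
+ #
+ #   blip2_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
+ #   blip2_model = Blip2ForConditionalGeneration.from_pretrained(
+ #       "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device)
+ #   score = calculate_blip_bleu("example.mp4", "a dog running on grass", blip2_model, blip2_processor)
+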
+ def calculate_sd_score(video_path, text, pipe, model):
+     # Load the video
+     output_dir = "../../SDXL_Imgs"
+     if not os.path.exists(output_dir):
+         os.mkdir(output_dir)
+     cap = cv2.VideoCapture(video_path)
+     to_tensor = transforms.ToTensor()
+     # Extract frames from the video
+     frames = []
+     SD_images = []
+     Num = 5
+     resize = transforms.Resize([224, 224])
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         # resized_frame = cv2.resize(frame, (224, 224))  # Resize the frame to match the expected input size
+         frames.append(frame)
+
+     # Load SD images from local paths, generating and caching them if they do not exist yet
+     for i in range(Num):  # Num images for every prompt
+         output_dir = "../../SDXL_Imgs"
+         # ipdb.set_trace()
+         SD_image_path = os.path.join(output_dir, f"{os.path.basename(video_path).split('.')[0]}_{i}.png")
+         if os.path.exists(SD_image_path):
+             image = Image.open(SD_image_path)
+             # Convert the image to a tensor
+             image = resize(to_tensor(image))
+             SD_images.append(image.unsqueeze(0))
+         else:
+             image = pipe(text, height=512, width=512, num_inference_steps=20).images[0]  # Same number of SD images per prompt; could be sampled multiple times (TODO)
+             # Convert the image to a tensor
+             image = resize(to_tensor(image))
+             SD_images.append(image.unsqueeze(0))
+             save_image(image, SD_image_path)
+
+     tensor_frames = [resize(torch.from_numpy(frame).permute(2, 0, 1).float()) for frame in frames]
+     SD_images = torch.cat(SD_images, 0)
+
+     concatenated_frame_features = []
+     concatenated_SDImg_features = []
+     # Generate embeddings for each frame and concatenate the features
+     with torch.no_grad():
+         for frame in tensor_frames:  # Embed frames one at a time; a full video may not fit in memory as a single CLIP batch
+             frame_input = frame.unsqueeze(0).to(device)  # Add batch dimension and move the frame to the device
+             frame_feature = model.get_image_features(frame_input)
+             concatenated_frame_features.append(frame_feature)
+
+         for i in range(SD_images.size()[0]):
+             img = SD_images[i].unsqueeze(0).to(device)  # Add batch dimension and move the image to the device
+             SDImg_feature = model.get_image_features(img)
+             concatenated_SDImg_features.append(SDImg_feature)
+     # ipdb.set_trace()
+     concatenated_frame_features = torch.cat(concatenated_frame_features, dim=0)
+     concatenated_SDImg_features = torch.cat(concatenated_SDImg_features, dim=0)
+
+     # Calculate the similarity scores
+     concatenated_frame_features = concatenated_frame_features / concatenated_frame_features.norm(p=2, dim=-1, keepdim=True)
+     concatenated_SDImg_features = concatenated_SDImg_features / concatenated_SDImg_features.norm(p=2, dim=-1, keepdim=True)
+     sd_score_frames = concatenated_frame_features @ concatenated_SDImg_features.T
+     # Average similarity between video frames and SDXL reference images
+     sd_score_frames_avg = sd_score_frames.mean().item()
+
+     return sd_score_frames_avg
+
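+ # Cached SDXL reference images are keyed by the video filename: a (hypothetical)
+ # video "../../videos/0001_dog.mp4" maps to "../../SDXL_Imgs/0001_dog_0.png" through "_4.png".
+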
+ def calculate_face_consistency_score(video_path, model):
+     # Load the video
+     cap = cv2.VideoCapture(video_path)
+     to_tensor = transforms.ToTensor()
+     # Extract frames from the video
+     frames = []
+     SD_images = []
+     resize = transforms.Resize([224, 224])
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         # resized_frame = cv2.resize(frame, (224, 224))  # Resize the frame to match the expected input size
+         frames.append(frame)
+
+     tensor_frames = [resize(torch.from_numpy(frame).permute(2, 0, 1).float()) for frame in frames]
+     concatenated_frame_features = []
+
+     # Generate embeddings for each frame and concatenate the features
+     with torch.no_grad():
+         for frame in tensor_frames:  # Embed frames one at a time; a full video may not fit in memory as a single CLIP batch
+             frame_input = frame.unsqueeze(0).to(device)  # Add batch dimension and move the frame to the device
+             frame_feature = model.get_image_features(frame_input)
+             concatenated_frame_features.append(frame_feature)
+
+     concatenated_frame_features = torch.cat(concatenated_frame_features, dim=0)
+
+     # Calculate the similarity scores
+     concatenated_frame_features = concatenated_frame_features / concatenated_frame_features.norm(p=2, dim=-1, keepdim=True)
+     face_consistency_score = concatenated_frame_features[1:] @ concatenated_frame_features[0].unsqueeze(0).T
+     # Average similarity of each frame to the first frame; reflects identity (face) consistency
+     face_consistency_score_avg = face_consistency_score.mean().item()
+
+     return face_consistency_score_avg
+
+ def read_text_file(file_path):
+     with open(file_path, 'r') as f:
+         return f.read().strip()
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--dir_videos", type=str, default='', help="Specify the path of generated videos")
+     parser.add_argument("--metric", type=str, default='celebrity_id_score', help="Specify the metric to be used")
+     args = parser.parse_args()
+
+     dir_videos = args.dir_videos
+     metric = args.metric
+
+     dir_prompts = '../../prompts/'
+
+     video_paths = [os.path.join(dir_videos, x) for x in os.listdir(dir_videos)]
+     prompt_paths = [os.path.join(dir_prompts, os.path.splitext(os.path.basename(x))[0] + '.txt') for x in video_paths]
+
+     # Create the results directory if it doesn't exist
+     timestamp = time.strftime("%Y%m%d-%H%M%S")
+     os.makedirs("../../results", exist_ok=True)
+     # Set up logging
+     log_file_path = f"../../results/{metric}_record.txt"
+     # Delete the log file if it exists
+     if os.path.exists(log_file_path):
+         os.remove(log_file_path)
+     logger = logging.getLogger()
+     logger.setLevel(logging.INFO)
+     # File handler for writing logs to a file
+     file_handler = logging.FileHandler(filename=log_file_path)
+     file_handler.setFormatter(logging.Formatter("%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
+     logger.addHandler(file_handler)
+     # Stream handler for displaying logs in the terminal
+     stream_handler = logging.StreamHandler()
+     stream_handler.setFormatter(logging.Formatter("%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
+     logger.addHandler(stream_handler)
+
+
+     # Load pretrained models
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     if metric == 'blip_bleu':
+         blip2_processor = AutoProcessor.from_pretrained("../../checkpoints/blip2-opt-2.7b")
+         blip2_model = Blip2ForConditionalGeneration.from_pretrained("../../checkpoints/blip2-opt-2.7b", torch_dtype=torch.float16).to(device)
+     elif metric == 'sd_score':
+         clip_model = CLIPModel.from_pretrained("../../checkpoints/clip-vit-base-patch32").to(device)
+         clip_tokenizer = AutoTokenizer.from_pretrained("../../checkpoints/clip-vit-base-patch32")
+         output_dir = "/apdcephfs/share_1290939/raphaelliu/Vid_Eval/Video_Gen/prompt700-release/SDXL_Imgs"
+         SD_image_path = os.path.join(output_dir, f"{os.path.basename(os.path.basename(video_paths[0]).split('.')[0])}_0.png")
+         # if os.path.exists(SD_image_path):
+         #     pipe = None
+         # else:
+         pipe = StableDiffusionXLPipeline.from_pretrained(
+             "../../checkpoints/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
+         pipe = pipe.to(device)
+     else:
+         clip_model = CLIPModel.from_pretrained("../../checkpoints/clip-vit-base-patch32").to(device)
+         clip_tokenizer = AutoTokenizer.from_pretrained("../../checkpoints/clip-vit-base-patch32")
+
+     # Calculate scores for all video-text pairs
+     scores = []
+
+     # test_num = 10  # set to a small number to smoke-test on a subset
+     test_num = len(video_paths)
+     count = 0
+     for i in tqdm(range(len(video_paths))):
+         video_path = video_paths[i]
+         prompt_path = prompt_paths[i]
+         if count == test_num:
+             break
+         else:
+             text = read_text_file(prompt_path)
+             # ipdb.set_trace()
+             if metric == 'clip_score':
+                 score = calculate_clip_score(video_path, text, clip_model, clip_tokenizer)
+             elif metric == 'blip_bleu':
+                 score = calculate_blip_bleu(video_path, text, blip2_model, blip2_processor)
+             elif metric == 'sd_score':
+                 score = calculate_sd_score(video_path, text, pipe, clip_model)
+             elif metric == 'clip_temp_score':
+                 score = calculate_clip_temp_score(video_path, clip_model)
+             elif metric == 'face_consistency_score':
+                 score = calculate_face_consistency_score(video_path, clip_model)
+             else:
+                 raise ValueError(f"Unsupported metric: {metric}")
+             count += 1
+             scores.append(score)
+             average_score = sum(scores) / len(scores)
+             # count += 1
+             logging.info(f"Vid: {os.path.basename(video_path)}, Current {metric}: {score}, Current avg. {metric}: {average_score}")
+
+     # Report the final average across all evaluated videos
+     logging.info(f"Final average {metric}: {average_score}, Total videos: {len(scores)}")
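+
+ # Example invocations (video directory is illustrative; prompts are read from
+ # ../../prompts/<video_basename>.txt and results are logged to ../../results/<metric>_record.txt):
+ #
+ #   python calculations.py --dir_videos ../../videos/model_A --metric clip_score
+ #   python calculations.py --dir_videos ../../videos/model_A --metric clip_temp_score
+ #   python calculations.py --dir_videos ../../videos/model_A --metric blip_bleu
+ #   python calculations.py --dir_videos ../../videos/model_A --metric sd_score
+ #   python calculations.py --dir_videos ../../videos/model_A --metric face_consistency_score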