from typing import Callable
import torch
import torchvision.transforms.functional as transforms_f
from einops import rearrange
from genie.factorization_utils import factorize_labels


class AvgMetric:
    """ Records a running sum and count to compute the mean. """

    def __init__(self):
        self.total = 0
        self.count = 0

    def update(self, val, batch_size=1):
        self.total += val * batch_size
        self.count += batch_size

    def update_list(self, flat_vals):
        self.total += sum(flat_vals)
        self.count += len(flat_vals)

    def mean(self):
        if self.count == 0:
            return 0
        return self.total / self.count
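
# Example usage (illustrative sketch only; the values below are made up):
#   metric = AvgMetric()
#   metric.update(0.25, batch_size=8)       # one batch-averaged value, weighted by batch size
#   metric.update_list([0.10, 0.30, 0.20])  # per-sample values
#   metric.mean()                           # -> running mean over all 11 samples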


def decode_tokens(reshaped_token_ids: torch.LongTensor, decode_latents: Callable) -> torch.ByteTensor:
    """
    Converts quantized latent-space tokens to images.

    Args:
        reshaped_token_ids: shape (B, T, H, W).
        decode_latents: instance of `decode_latents_wrapper()`

    Returns:
        (B, T, 3, 256, 256)
    """
    decoded_imgs = decode_latents(rearrange(reshaped_token_ids, "b t h w -> (b t) h w").cpu().numpy())
    decoded_tensor = torch.stack([transforms_f.pil_to_tensor(pred_img) for pred_img in decoded_imgs])
    return rearrange(decoded_tensor, "(b t) c H W -> b t c H W", b=reshaped_token_ids.size(0))
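
# Example usage (a minimal sketch; `decode_latents_wrapper()` is referenced in the docstring
# above but is provided by the accompanying tokenizer utilities, not by this module):
#   decode_latents = decode_latents_wrapper()
#   token_ids = torch.zeros(2, 4, 16, 16, dtype=torch.long)  # dummy (B, T, H, W) token ids
#   imgs = decode_tokens(token_ids, decode_latents)          # -> (2, 4, 3, 256, 256) uint8 frames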


def decode_features(reshaped_token_ids: torch.LongTensor, decode_latents: Callable) -> torch.ByteTensor:
    """
    Converts latent-space features to images.

    Args:
        reshaped_token_ids: shape (B, T, H, W, C).
        decode_latents: instance of `decode_latents_wrapper()`

    Returns:
        (B, T, 3, 256, 256)
    """
    decoded_imgs = decode_latents(rearrange(reshaped_token_ids, "b t h w c -> (b t) c h w").cpu().numpy())
    decoded_tensor = torch.stack([transforms_f.pil_to_tensor(pred_img) for pred_img in decoded_imgs])
    return rearrange(decoded_tensor, "(b t) c H W -> b t c H W", b=reshaped_token_ids.size(0))
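
# Example usage (a sketch; the (B, T, H, W, C) layout and the channel size below are
# inferred from the einops pattern above, not specified elsewhere in this module):
#   feats = torch.randn(2, 4, 16, 16, 32)          # dummy (B, T, H, W, C) latent features
#   imgs = decode_features(feats, decode_latents)  # -> (2, 4, 3, 256, 256) uint8 frames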


def compute_loss(
    labels_flat: torch.LongTensor,
    factored_logits: torch.FloatTensor,
    num_factored_vocabs: int = 2,
    factored_vocab_size: int = 512,
) -> float:
    """
    If applicable (model returns logits), computes the cross entropy loss.
    In the case of a factorized vocabulary, sums the cross entropy losses across the vocabularies.
    Assumes all submissions use the parametrization num_factored_vocabs = 2, factored_vocab_size = 512.

    Args:
        labels_flat: size (B, T*H*W) corresponding to flattened, tokenized images.
        factored_logits: size (B, factored_vocab_size, num_factored_vocabs, T-1, H, W).
            E.g. output of `genie.evaluate.GenieEvaluator.predict_zframe_logits()`
        num_factored_vocabs: Should be 2 for v1.0 of the challenge.
        factored_vocab_size: Should be 512 for v1.0 of the challenge.

    Returns:
        Cross entropy loss
    """
    assert factored_logits.dim() == 6 \
        and factored_logits.size()[:3] == (labels_flat.size(0), factored_vocab_size, num_factored_vocabs), \
        f"Shape of `logits` should be (B, {factored_vocab_size}, {num_factored_vocabs}, T-1, H, W)"

    t = factored_logits.size(3) + 1
    h, w = factored_logits.size()[-2:]
    assert t * h * w == labels_flat.size(1), "Shape of `factored_logits` does not match flattened latent image size."

    labels_THW = rearrange(labels_flat, "b (t h w) -> b t h w", t=t, h=h, w=w)
    # Drop the first frame's labels: the logits only predict frames 2..T
    labels_THW = labels_THW[:, 1:].to(factored_logits.device)
    factored_labels = factorize_labels(labels_THW, num_factored_vocabs, factored_vocab_size)

    # Final loss is the sum of the two losses across the size-512 vocabularies
    return torch.nn.functional.cross_entropy(factored_logits, factored_labels, reduction="none") \
        .sum(dim=1).mean().item()
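
# Example usage (a minimal sketch with random tensors, purely to illustrate the expected shapes):
#   B, T, H, W = 2, 16, 16, 16
#   labels_flat = torch.randint(0, 512 ** 2, (B, T * H * W))  # full vocab is 512**2 factorized tokens
#   factored_logits = torch.randn(B, 512, 2, T - 1, H, W)     # (B, vocab, n_vocabs, T-1, H, W)
#   loss = compute_loss(labels_flat, factored_logits)         # scalar cross entropy (float)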


def compute_lpips(frames_a: torch.ByteTensor, frames_b: torch.ByteTensor, lpips_func: Callable) -> list:
    """
    Given two batches of video data of shape (B, T, 3, 256, 256), computes the LPIPS score frame by frame.
    Cannot use `lpips_func` directly because it expects at most 4D input.
    """
    # LPIPS expects pixel values in [-1, 1]
    flattened_a, flattened_b = [rearrange(frames / 127.5 - 1, "b t c H W -> (b t) c H W")
                                for frames in (frames_a, frames_b)]
    return lpips_func(flattened_a, flattened_b).flatten().tolist()
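
# Example usage (a sketch; assumes the external `lpips` package, e.g. `lpips.LPIPS(net="alex")`,
# which is not imported in this module):
#   lpips_func = lpips.LPIPS(net="alex")
#   frames_a = torch.randint(0, 256, (2, 4, 3, 256, 256), dtype=torch.uint8)
#   frames_b = torch.randint(0, 256, (2, 4, 3, 256, 256), dtype=torch.uint8)
#   scores = compute_lpips(frames_a, frames_b, lpips_func)  # list of B*T per-frame LPIPS values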