# Page-header residue from the Hugging Face Spaces file viewer ("Spaces: Running on Zero"); not part of the module.
from typing import Callable | |
import torch | |
import torchvision.transforms.functional as transforms_f | |
from einops import rearrange | |
from genie.factorization_utils import factorize_labels | |
class AvgMetric:
    """Accumulates a running sum and a sample count so a mean can be read at any time."""

    def __init__(self):
        # `total` holds the (weighted) sum of everything recorded so far,
        # `count` holds how many samples that sum represents.
        self.total = 0
        self.count = 0

    def update(self, val, batch_size=1):
        """Record `val` as the average over `batch_size` samples.

        The value is weighted by `batch_size` before being added to the sum,
        so batches of different sizes contribute proportionally to the mean.
        """
        weighted_val = val * batch_size
        self.total += weighted_val
        self.count += batch_size

    def update_list(self, flat_vals):
        """Record every element of `flat_vals` as one individual sample.

        `flat_vals` must support both `sum()` and `len()`.
        """
        self.total += sum(flat_vals)
        self.count += len(flat_vals)

    def mean(self):
        """Return the mean of all recorded samples, or 0 if nothing was recorded."""
        if self.count != 0:
            return self.total / self.count
        return 0
def decode_tokens(reshaped_token_ids: torch.LongTensor, decode_latents: Callable) -> torch.ByteTensor:
    """
    Decodes a batch of quantized latent token grids into image frames.

    Args:
        reshaped_token_ids: token ids of shape (B, T, H, W).
        decode_latents: instance of `decode_latents_wrapper()`. It is called once with a
            numpy array of shape (B*T, H, W) and must return an iterable of PIL images
            (each one is passed to `pil_to_tensor`).

    Returns:
        Decoded frames of shape (B, T, 3, 256, 256).
    """
    batch_size = reshaped_token_ids.size(0)

    # Fold time into the batch axis so every frame is decoded independently.
    per_frame_ids = rearrange(reshaped_token_ids, "b t h w -> (b t) h w")
    pil_frames = decode_latents(per_frame_ids.cpu().numpy())

    frame_tensors = [transforms_f.pil_to_tensor(pil_frame) for pil_frame in pil_frames]
    stacked_frames = torch.stack(frame_tensors)

    # Restore the separate batch and time axes.
    return rearrange(stacked_frames, "(b t) c H W -> b t c H W", b=batch_size)
def decode_features(reshaped_token_ids: torch.LongTensor, decode_latents: Callable) -> torch.ByteTensor:
    """
    Converts channel-last latent feature maps to images.

    Unlike `decode_tokens`, which takes a 4-D grid of token ids, this function takes a
    5-D input with a trailing channel axis and moves that axis in front of the spatial
    axes before decoding.

    Args:
        reshaped_token_ids: shape (B, T, H, W, C). The parameter name and the `LongTensor`
            annotation are kept for backward compatibility.
            NOTE(review): features are presumably floating point rather than integer
            token ids - confirm against callers.
        decode_latents: callable that receives a numpy array of shape (B*T, C, H, W) and
            returns an iterable of PIL images (each one is passed to `pil_to_tensor`).

    Returns:
        Decoded frames with the batch and time axes restored: (B, T, 3, 256, 256)
        according to the original contract of this module.
    """
    batch_size = reshaped_token_ids.size(0)

    # Fold time into the batch axis and switch from channel-last to channel-first.
    per_frame_features = rearrange(reshaped_token_ids, "b t h w c -> (b t) c h w")
    decoded_imgs = decode_latents(per_frame_features.cpu().numpy())

    decoded_tensor = torch.stack([transforms_f.pil_to_tensor(pred_img) for pred_img in decoded_imgs])
    return rearrange(decoded_tensor, "(b t) c H W -> b t c H W", b=batch_size)
def compute_loss(
    labels_flat: torch.LongTensor,
    factored_logits: torch.FloatTensor,
    num_factored_vocabs: int = 2,
    factored_vocab_size: int = 512,
) -> float:
    """
    Computes the cross entropy loss of factorized logits against flattened token labels.

    With a factorized vocabulary, the cross entropy losses of the individual
    vocabularies are summed and the result is averaged over all remaining positions.
    All submissions are assumed to use num_factored_vocabs = 2, factored_vocab_size = 512.

    Args:
        labels_flat: size (B, T*H*W), flattened, tokenized images.
        factored_logits: size (B, factored_vocab_size, num_factored_vocabs, T-1, H, W).
            E.g. output of `genie.evaluate.GenieEvaluator.predict_zframe_logits()`
        num_factored_vocabs: Should be 2 for v1.0 of the challenge.
        factored_vocab_size: Should be 512 for v1.0 of the challenge.

    Returns:
        Cross entropy loss as a Python float.

    Raises:
        AssertionError: if `factored_logits` does not have the expected leading
            dimensions, or if its (T, H, W) does not match the length of `labels_flat`.
    """
    assert (
        factored_logits.dim() == 6
        and factored_logits.size()[:3] == (labels_flat.size(0), factored_vocab_size, num_factored_vocabs)
    ), f"Shape of `logits` should be (B, {factored_vocab_size}, {num_factored_vocabs}, T-1, H, W)"

    # The logits cover T-1 frames, so the label sequence is one frame longer.
    num_frames = factored_logits.size(3) + 1
    height, width = factored_logits.shape[-2:]
    assert num_frames * height * width == labels_flat.size(1), \
        "Shape of `factored_logits` does not match flattened latent image size."

    label_grid = rearrange(labels_flat, "b (t h w) -> b t h w", t=num_frames, h=height, w=width)
    # Drop the first frame of the labels so they line up with the T-1 frames of logits.
    target_grid = label_grid[:, 1:].to(factored_logits.device)
    factored_labels = factorize_labels(target_grid, num_factored_vocabs, factored_vocab_size)

    unreduced_loss = torch.nn.functional.cross_entropy(factored_logits, factored_labels, reduction="none")
    # Final loss is the sum of the per-vocabulary losses (dim=1 of the unreduced loss),
    # averaged over everything else.
    summed_over_vocabs = unreduced_loss.sum(dim=1)
    return summed_over_vocabs.mean().item()
def compute_lpips(frames_a: torch.ByteTensor, frames_b: torch.ByteTensor, lpips_func: Callable) -> list:
    """
    Computes frame-by-frame LPIPS scores between two batches of video data.

    Both inputs have shape (B, T, 3, 256, 256). The batch and time axes are merged
    before scoring because `lpips_func` expects at most 4D input, so it cannot be
    applied to the videos directly.

    Returns:
        A flat Python list of the values returned by `lpips_func` for the merged frames.
    """

    def _as_lpips_input(frames):
        # LPIPS expects pixel values between [-1, 1]
        rescaled = frames / 127.5 - 1
        return rearrange(rescaled, "b t c H W -> (b t) c H W")

    flattened_a = _as_lpips_input(frames_a)
    flattened_b = _as_lpips_input(frames_b)
    scores = lpips_func(flattened_a, flattened_b)
    return scores.flatten().tolist()