# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/lucidrains/NWT-pytorch/
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import EinMix as Mix


class Memcodes(nn.Module):
    """Multi-head "memcodes" quantizer: inputs attend over a learned codebook
    via a straight-through Gumbel-softmax, following NWT
    (https://github.com/lucidrains/NWT-pytorch/).
    """
    def __init__(
        self,
        *,
        dim,
        codebook_size,
        heads = 1,
        temperature = 1.,
        channel_last = False,
        accept_image_fmap = True,
        **kwargs,
    ):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
        self.heads = heads
        self.dim = dim
        self.scale = (dim // heads) ** -0.5
        self.temperature = temperature
        self.codebook_size = codebook_size
        self.accept_image_fmap = accept_image_fmap
        self.channel_last = channel_last

        num_codebooks = heads
        codebook_dim = dim // heads

        # One learned codebook per head
        self.codes = nn.Parameter(torch.randn(num_codebooks, codebook_size, codebook_dim))
        # Per-head linear projections producing the keys / values of the codes
        self.to_k = Mix('h n d -> h n c', weight_shape = 'h d c', h = heads, d = codebook_dim, c = codebook_dim)
        self.to_v = Mix('h n d -> h n c', weight_shape = 'h d c', h = heads, d = codebook_dim, c = codebook_dim)

    def indices_to_embedding(self, indices):
        """Look up the value-projected code embeddings for the given code indices."""
        batch = indices.shape[0]
        values = self.to_v(self.codes)
        values = repeat(values, 'h n d -> b h n d', b = batch)
        indices = repeat(indices, '... -> ... d', d = values.shape[-1]).squeeze(2)
        if self.accept_image_fmap and indices.dim() == 4:
            out = values.gather(2, indices)
            return rearrange(out, 'b h n d -> b (h d) n 1')
        else:
            out = values.gather(2, indices.unsqueeze(2))
            return rearrange(out, 'b h n d -> b n (h d)')

    def forward(self, x):
        need_transpose = not self.channel_last and not self.accept_image_fmap

        if self.accept_image_fmap:
            height, width = x.shape[-2:]
            x = rearrange(x, 'b c h w -> b (h w) c')

        if need_transpose:
            x = rearrange(x, 'b d n -> b n d')

        assert x.shape[-1] == self.dim

        # Split out heads
        q = rearrange(x, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # Get keys / values of the codes
        k, v = self.to_k(self.codes), self.to_v(self.codes)

        # Straight-through Gumbel-softmax over the code logits
        logits = einsum('b h i d, h j d -> b h i j', q, k)

        if self.training:
            attn = F.gumbel_softmax(logits, tau = self.temperature, dim = -1, hard = True)
            codebook_indices = attn.argmax(dim = -1)
        else:
            codebook_indices = logits.argmax(dim = -1)
            attn = F.one_hot(codebook_indices, num_classes = self.codebook_size).float()

        if self.heads == 1:
            codebook_indices = codebook_indices.squeeze(1)

        out = einsum('b h i j, h j d -> b h i d', attn, v)

        # Merge heads
        out = rearrange(out, 'b h n d -> b n (h d)')

        if need_transpose:
            out = rearrange(out, 'b n d -> b d n')

        if self.accept_image_fmap:
            out = rearrange(out, 'b (h w) c -> b c h w', h = height, w = width)
            if self.heads == 1:
                codebook_indices = rearrange(codebook_indices, 'b (h w) -> b h w', h = height, w = width)
            else:
                codebook_indices = rearrange(codebook_indices, 'b n (h w) -> b n h w', h = height, w = width)

        # Dummy codebook loss for compatibility with other types of quantizers
        codebook_loss = torch.tensor([0.], device=x.device)

        return out, codebook_loss, codebook_indices
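

# A minimal usage sketch (not part of the original file): smoke-tests the
# encode and decode paths under the default accept_image_fmap=True convention,
# i.e. inputs of shape (B, C, H, W) with C == dim. All sizes here are
# illustrative assumptions, not values from the source.
if __name__ == '__main__':
    quantizer = Memcodes(dim = 256, codebook_size = 1024, heads = 4)
    quantizer.eval()  # eval mode uses the deterministic argmax lookup instead of Gumbel sampling

    x = torch.randn(2, 256, 8, 8)
    out, codebook_loss, codebook_indices = quantizer(x)
    print(out.shape)               # torch.Size([2, 256, 8, 8]), same shape as the input
    print(codebook_indices.shape)  # torch.Size([2, 4, 8, 8]), one code index per head and position

    # Decode: indices_to_embedding expects flattened per-head indices of shape (B, H, N)
    flat_indices = rearrange(codebook_indices, 'b h ph pw -> b h (ph pw)')
    emb = quantizer.indices_to_embedding(flat_indices)
    print(emb.shape)  # torch.Size([2, 256, 64, 1])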