from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ

from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet


@dataclass
class FSQResult:
    z: torch.Tensor  # quantized features, upsampled back to the input resolution
    codes: torch.Tensor  # discrete FSQ code indices
    latents: torch.Tensor  # downsampled latents, before quantization


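# Illustration only (hypothetical helper, not part of the original file): a
# minimal sketch of the finite-scalar-quantization idea this module builds on.
# Each latent dimension is squashed into a bounded range and snapped to one of
# `levels[i]` values, so levels=(8, 5, 5, 5) yields an implicit codebook of
# 8 * 5 * 5 * 5 = 1000 entries. The real quantizer (GroupedResidualFSQ from
# vector_quantize_pytorch) additionally uses straight-through gradients and
# residual quantization stages.
def _fsq_round_sketch(z: torch.Tensor, levels: tuple[int, ...]) -> torch.Tensor:
    lv = torch.tensor(levels, dtype=z.dtype)
    x = (torch.tanh(z) + 1) / 2 * (lv - 1)  # map each dim into [0, levels[i] - 1]
    codes = torch.round(x)  # snap to the nearest of levels[i] discrete values
    return codes / (lv - 1) * 2 - 1  # rescale back to [-1, 1]

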
class DownsampleFiniteScalarQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        n_groups: int = 1,
        levels: tuple[int, ...] = (8, 5, 5, 5),  # 8 * 5 * 5 * 5 = 1000 entries per codebook
        downsample_factor: tuple[int, ...] = (2, 2),
        downsample_dims: tuple[int, ...] | None = None,
    ):
        super().__init__()

        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]

        all_dims = (input_dim,) + tuple(downsample_dims)

        # The feature dimension is split into `n_groups` groups, each quantized
        # by its own residual FSQ with `n_codebooks` stages.
        self.residual_fsq = GroupedResidualFSQ(
            dim=all_dims[-1],
            levels=levels,
            num_quantizers=n_codebooks,
            groups=n_groups,
        )
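        # Back-of-envelope note (not from the original source): with the
        # defaults, each downsampled frame yields n_groups * n_codebooks = 9
        # indices, each drawn from a 1000-entry codebook (~10 bits), i.e.
        # roughly 90 bits per frame before any entropy coding.
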
        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims

        # Strided conv stages; each shrinks the time axis by its factor, so the
        # overall reduction is prod(downsample_factor).
        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishConvNet(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )

        # Transposed-conv mirror of the downsample stack, applied in reverse.
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishTransConvNet(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:  # conv/linear layers may be built without bias
                nn.init.constant_(m.bias, 0)

    def forward(self, z) -> FSQResult:
        original_shape = z.shape  # (batch, dim, time)
        z = self.downsample(z)
        # GroupedResidualFSQ expects (batch, time, dim); .mT swaps the last two dims
        quantized, indices = self.residual_fsq(z.mT)
        result = FSQResult(
            z=quantized.mT,
            codes=indices.mT,
            latents=z,
        )
        result.z = self.upsample(result.z)

        # The conv/transposed-conv round trip may change the sequence length;
        # pad (or crop) symmetrically so result.z matches the input time axis.
        diff = original_shape[-1] - result.z.shape[-1]
        left = diff // 2
        right = diff - left

        if diff > 0:
            result.z = F.pad(result.z, (left, right))
        elif diff < 0:
            result.z = result.z[..., left:-right]

        return result

    def encode(self, z):
        z = self.downsample(z)
        _, indices = self.residual_fsq(z.mT)
        # (groups, batch, length, n_quantizers) -> (batch, groups * n_quantizers, length)
        indices = rearrange(indices, "g b l r -> b (g r) l")
        return indices

    def decode(self, indices: torch.Tensor):
        # Undo the encode() layout before looking up the codebook outputs
        indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
        z_q = self.residual_fsq.get_output_from_indices(indices)
        z_q = self.upsample(z_q.mT)
        return z_q


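if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. Assumes a
    # (batch, dim, time) input layout and that the file is imported as part of
    # its package so the relative `.firefly` import resolves.
    model = DownsampleFiniteScalarQuantize(input_dim=512, n_codebooks=9)
    x = torch.randn(2, 512, 100)

    result = model(x)
    print(result.z.shape)  # (2, 512, 100): forward() pads/crops back to the input length
    print(result.latents.shape)  # time axis reduced by prod((2, 2)) == 4

    # Round-trip through the discrete codes. Note that decode() applies no
    # final pad/crop, so its output length may differ slightly from the input.
    codes = model.encode(x)  # (batch, n_groups * n_codebooks, downsampled length)
    recon = model.decode(codes)
    print(codes.shape, recon.shape)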