Spaces:
Runtime error
Runtime error
# Copyright 2024 EPFL and Apple Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
lm: latent mapping | |
""" | |
from typing import Optional | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_up_block | |
# Submodule names listed for freezing. FREEZE_MODULES is not referenced in the
# code of this module -- presumably consumed by a training script (TODO confirm).
FREEZE_MODULES = [
    "encoder",
    "quant_proj",
    "quantize",
    "cls_emb",
]
class Token2VAE(nn.Module):
    """Latent-mapping network: maps an input latent grid (``quant``) to VAE latents.

    Architecture (built from diffusers blocks):
    ``conv_in`` -> ``UNetMidBlock2D`` -> one up block per entry of
    ``up_block_types`` (all but the last add an upsampling layer) ->
    GroupNorm -> SiLU -> ``conv_out``.

    The output has 4 channels when ``output_type == "sample"`` (a latent
    sample) or 8 channels when ``output_type == "stats"`` (mean and std, 4
    channels each, concatenated along dim 1).

    Args:
        in_channels: Number of channels of the ``quant`` input.
        output_type: Either ``"stats"`` or ``"sample"``.
        up_block_types: diffusers up-block type names; must have the same
            length as ``block_out_channels``.
        block_out_channels: Channel widths. Index -1 is the width after
            ``conv_in`` / in the mid block; index 0 is the width fed to the
            output head.
        layers_per_block: Each up block is built with ``layers_per_block + 1``
            layers.
        norm_num_groups: Group count for every GroupNorm.
        act_fn: Activation name passed to the diffusers blocks.
        vq_model: Optional model whose ``encode(image)`` returns a 3-tuple; the
            first element is used as ``quant`` in :meth:`forward`.
        vae: Optional VAE used by :meth:`vae_encode` / :meth:`vae_decode`.
            NOTE: ``vq_model`` / ``vae`` are stored as attributes, so if they
            are ``nn.Module`` instances they are registered as submodules.

    Raises:
        AssertionError: If ``output_type`` is not ``"stats"`` or ``"sample"``.
        ValueError: If ``up_block_types`` and ``block_out_channels`` differ
            in length.
    """

    def __init__(
        self,
        in_channels=32,
        output_type="stats",
        up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D",),
        block_out_channels=(256, 512),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        vq_model=None,
        vae=None,
    ):
        super().__init__()
        assert output_type in ["stats", "sample"], "`output_type` can be either of 'stats' or 'sample'"
        # Each up block is paired with one channel width. A mismatch either
        # indexes past `block_out_channels` or leaves no block marked final,
        # yielding a model whose channel counts only fail at forward time.
        if len(up_block_types) != len(block_out_channels):
            raise ValueError(
                f"`up_block_types` ({len(up_block_types)} entries) and "
                f"`block_out_channels` ({len(block_out_channels)} entries) "
                "must have the same length"
            )
        self.output_type = output_type
        # "sample" -> 4-channel latent; "stats" -> mean (4) + std (4).
        out_channels = 4 if output_type == "sample" else 8
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # `up_blocks` is registered before `mid_block` on purpose: submodule
        # registration order determines parameters()/state_dict() ordering,
        # and it is kept identical to earlier versions of this class.
        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
        )

        # up: walk the widths from the input side (widest, index -1) toward
        # the output side (index 0).
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
                resnet_time_scale_shift="group",
            )
            self.up_blocks.append(up_block)

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.vq_model = vq_model
        self.vae = vae

    def vae_encode(self, x):
        """Encode ``x`` with ``self.vae`` into a scaled training target.

        Returns a latent sample (``output_type == "sample"``) or the latent
        distribution's mean and std concatenated along dim 1 (``"stats"``),
        multiplied by ``self.vae.config.scaling_factor``.

        Raises:
            AssertionError: If no VAE was provided.
        """
        assert self.vae is not None, "VAE is not initialized"
        z = self.vae.encode(x).latent_dist
        if self.output_type == "sample":
            z = z.sample()
        else:
            z = torch.cat((z.mean, z.std), dim=1)
        # Scaling std together with mean keeps `sample()` consistent:
        # s * mean + s * std * eps == s * (mean + std * eps).
        z = z * self.vae.config.scaling_factor
        return z

    def vae_decode(self, x, clip=True):
        """Decode a scaled latent (or mean/std stats) with ``self.vae``.

        Args:
            x: A 4-channel latent, or a tensor of concatenated mean/std which
                is first turned into a latent via :meth:`sample`.
            clip: If True, clamp the decoded output to ``[-1, 1]``.

        Raises:
            AssertionError: If no VAE was provided.
        """
        assert self.vae is not None, "VAE is not initialized"
        x = self.sample(x)
        # Undo the scaling applied in `vae_encode` before decoding.
        x = self.vae.decode(x / self.vae.config.scaling_factor).sample
        if clip:
            x = torch.clip(x, min=-1, max=1)
        return x

    def sample(self, x):
        """Return ``x`` unchanged if it has 4 channels, else sample from stats.

        A non-4-channel tensor is split in half along dim 1 into ``mean`` and
        ``std``, and ``mean + std * N(0, 1)`` noise is returned.
        """
        if x.shape[1] == 4:
            return x
        mean, std = x.chunk(2, dim=1)
        x = mean + std * torch.randn_like(std)
        return x

    def forward(self, quant=None, image=None):
        """Map ``quant`` (or ``image`` via ``self.vq_model``) to VAE latents.

        Args:
            quant: Input tensor with ``in_channels`` channels. If None, it is
                computed as the first element of ``self.vq_model.encode(image)``.
            image: Only used when ``quant`` is None.

        Returns:
            Tensor with 4 channels (``"sample"``) or 8 channels (``"stats"``).

        Raises:
            AssertionError: If neither input is given, or ``image`` is given
                without a ``vq_model``.
        """
        if quant is None:
            assert image is not None, "Neither of `quant` or `image` are provided"
            assert self.vq_model is not None, "VQ encoder is not initialized"
            # No gradients flow into the VQ encoder.
            with torch.no_grad():
                quant, _, _ = self.vq_model.encode(image)

        x = self.conv_in(quant)
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        # middle
        x = self.mid_block(x)
        # Match the up blocks' parameter dtype (presumably for mixed-precision
        # setups where the mid block runs in a different dtype).
        x = x.to(upscale_dtype)

        # up
        for up_block in self.up_blocks:
            x = up_block(x)

        # post-process
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x
def create_model(
    in_channels=32,
    output_type="stats",
    vq_model=None,
    vae=None,
):
    """Build a ``Token2VAE`` with this module's fixed architecture.

    The architecture is two ``UpDecoderBlock2D`` up blocks, channel widths
    ``(256, 512)``, 2 layers per block, 32 norm groups and SiLU activation.
    Only the input width, output type and the optional VQ model / VAE are
    configurable.
    """
    architecture = dict(
        up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
        block_out_channels=(256, 512),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
    )
    return Token2VAE(
        in_channels=in_channels,
        output_type=output_type,
        vq_model=vq_model,
        vae=vae,
        **architecture,
    )