# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
lm: latent mapping
"""
import torch
import torch.nn as nn
# NOTE: newer diffusers releases also expose these under `diffusers.models.unets.unet_2d_blocks`
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_up_block
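
# Tokenizer sub-modules kept frozen during training; only the list is defined here
# (assumption: the actual freezing is applied by the training code that imports it).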
FREEZE_MODULES = ['encoder', 'quant_proj', 'quantize', 'cls_emb']


class Token2VAE(nn.Module):
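    """Decoder head that maps VQ token embeddings to a diffusers VAE latent.

    Depending on `output_type`, the head predicts either a 4-channel sampled
    latent ("sample") or the 8-channel concatenated (mean, std) statistics ("stats").
    """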
def __init__(
self,
in_channels=32,
output_type="stats",
up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D",),
block_out_channels=(256, 512),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
vq_model=None,
vae=None,
):
super().__init__()
        assert output_type in ["stats", "sample"], "`output_type` must be either 'stats' or 'sample'"
        self.output_type = output_type
        # 4 channels for a sampled latent, 8 for concatenated (mean, std) statistics.
        out_channels = 4 if output_type == "sample" else 8
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=None,
resnet_time_scale_shift="group",
)
self.up_blocks.append(up_block)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
self.vq_model = vq_model
        self.vae = vae

    @torch.no_grad()
def vae_encode(self, x):
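        """Encode images with the frozen VAE into the target representation:
        a sampled latent or concatenated (mean, std) stats, scaled by the VAE's
        `scaling_factor`."""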
assert self.vae is not None, "VAE is not initialized"
z = self.vae.encode(x).latent_dist
if self.output_type == "sample":
z = z.sample()
else:
z = torch.cat((z.mean, z.std), dim=1)
z = z * self.vae.config.scaling_factor
        return z

    @torch.no_grad()
def vae_decode(self, x, clip=True):
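        """Map latents (stats or samples) back to image space, optionally clipping to [-1, 1]."""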
assert self.vae is not None, "VAE is not initialized"
x = self.sample(x)
x = self.vae.decode(x / self.vae.config.scaling_factor).sample
if clip:
x = torch.clip(x, min=-1, max=1)
        return x

    def sample(self, x):
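        """Reparameterization: draw a latent from (mean, std) stats; 4-channel inputs are already samples."""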
if x.shape[1] == 4:
return x
mean, std = x.chunk(2, dim=1)
x = mean + std * torch.randn_like(std)
        return x

    def forward(self, quant=None, image=None):
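        """Predict the VAE latent from token embeddings `quant`, tokenizing `image` first if `quant` is None."""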
        if quant is None:
            assert image is not None, "Either `quant` or `image` must be provided"
            assert self.vq_model is not None, "VQ encoder is not initialized"
with torch.no_grad():
quant, _, _ = self.vq_model.encode(image)
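        # Project token embeddings up to the widest decoder channel width.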
x = self.conv_in(quant)
        # Match the dtype of the up blocks (e.g. if the decoder was cast to half precision).
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
# middle
x = self.mid_block(x)
x = x.to(upscale_dtype)
# up
for up_block in self.up_blocks:
x = up_block(x)
# post-process
x = self.conv_norm_out(x)
x = self.conv_act(x)
x = self.conv_out(x)
        return x


def create_model(
in_channels=32,
output_type="stats",
vq_model=None,
vae=None,
):
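    """Build a Token2VAE with the default two-block (256, 512) up-sampling decoder."""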
return Token2VAE(
in_channels=in_channels,
output_type=output_type,
up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D",),
block_out_channels=(256, 512),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
vq_model=vq_model,
vae=vae,
)
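

# Minimal usage sketch. The tokenizer and VAE below are assumptions, not defined in
# this file: any VQ model exposing `.encode(image) -> (quant, ...)` plus a diffusers
# `AutoencoderKL` matches the interfaces used above.
#
#   from diffusers import AutoencoderKL
#
#   vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
#   model = create_model(in_channels=32, output_type="stats", vq_model=vq_model, vae=vae)
#   stats = model(image=images)      # predicted (mean, std) of the VAE latent
#   recon = model.vae_decode(stats)  # draw a latent and decode back to [-1, 1] images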