# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block


@dataclass
class UNet1DOutput(BaseOutput):
    """
    Output of [`UNet1DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
            Hidden states output. Output of last layer of model.
    """

    sample: torch.FloatTensor


class UNet1DModel(ModelMixin, ConfigMixin):
    r"""
    UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all models (such as downloading or saving).

    Parameters:
        sample_size (`int`, *optional*, defaults to 65536):
            Default length of sample. Should be adaptable at runtime.
        sample_rate (`int`, *optional*, defaults to `None`):
            Only recorded in the model config; not used by the model itself.
        in_channels (`int`, *optional*, defaults to 2):
            Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 2):
            Number of channels in the output.
        extra_in_channels (`int`, *optional*, defaults to 0):
            Number of additional channels to be added to the input of the first down block. Useful for cases where
            the input data has more channels than what the model is initially designed for.
        time_embedding_type (`str`, *optional*, defaults to `"fourier"`):
            Type of time embedding to use: `"fourier"` (uses `GaussianFourierProjection`) or `"positional"` (uses
            `Timesteps`). Any other value raises a `ValueError`.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos in the time embedding (used by both embedding types).
        use_timestep_embedding (`bool`, *optional*, defaults to `False`):
            If `True`, the projected timestep is passed through a `TimestepEmbedding` MLP. If `False`, the projection
            is instead repeated along the length dimension and broadcast to the sample batch size.
        freq_shift (`float`, *optional*, defaults to 0.0):
            Frequency shift (`downscale_freq_shift`) for the `"positional"` time embedding; ignored for `"fourier"`.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
            Tuple of upsample block types.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`):
            Block type for the middle of the UNet.
        out_block_type (`str`, *optional*, defaults to `None`):
            Optional output block type. If `None`, the last up block emits `out_channels` directly; otherwise the last
            up block emits `block_out_channels[0]` channels and the output block maps them to `out_channels`.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function, passed to the timestep MLP and to the output block.
        norm_num_groups (`int`, *optional*, defaults to 8):
            Number of groups (`num_groups_out`) for the output block's normalization. If `None`,
            `min(block_out_channels[0] // 4, 32)` is used.
        layers_per_block (`int`, *optional*, defaults to 1):
            Number of layers in each down, mid, and up block.
        downsample_each_block (`bool`, *optional*, defaults to `False`):
            Experimental feature for using a UNet without upsampling. If `True`, the final down block and the mid
            block also downsample.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 65536,
        sample_rate: Optional[int] = None,
        in_channels: int = 2,
        out_channels: int = 2,
        extra_in_channels: int = 0,
        time_embedding_type: str = "fourier",
        flip_sin_to_cos: bool = True,
        use_timestep_embedding: bool = False,
        freq_shift: float = 0.0,
        down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
        up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
        mid_block_type: str = "UNetMidBlock1D",
        out_block_type: Optional[str] = None,
        block_out_channels: Tuple[int] = (32, 32, 64),
        act_fn: Optional[str] = None,
        norm_num_groups: Optional[int] = 8,
        layers_per_block: int = 1,
        downsample_each_block: bool = False,
    ):
        super().__init__()
        self.sample_size = sample_size

        # time
        if time_embedding_type == "fourier":
            # NOTE(review): `timestep_input_dim` is only consumed by `time_mlp`; confirm it matches the output width
            # of this projection (embedding_size=8) before combining "fourier" with `use_timestep_embedding=True`.
            self.time_proj = GaussianFourierProjection(
                embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(
                block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
            )
            timestep_input_dim = block_out_channels[0]
        else:
            # Fail fast: otherwise `time_proj` would be missing and the error would only surface in `forward`.
            raise ValueError(
                f"Unsupported time_embedding_type {time_embedding_type!r}; expected 'fourier' or 'positional'."
            )

        if use_timestep_embedding:
            time_embed_dim = block_out_channels[0] * 4
            self.time_mlp = TimestepEmbedding(
                in_channels=timestep_input_dim,
                time_embed_dim=time_embed_dim,
                act_fn=act_fn,
                out_dim=block_out_channels[0],
            )

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])
        self.out_block = None

        # down
        output_channel = in_channels
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            # Extra input channels are only added to the first down block.
            if i == 0:
                input_channel += extra_in_channels

            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_downsample=not is_final_block or downsample_each_block,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            embed_dim=block_out_channels[0],
            num_layers=layers_per_block,
            add_downsample=downsample_each_block,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        # Without an output block, the last up block must already produce `out_channels`.
        if out_block_type is None:
            final_upsample_channels = out_channels
        else:
            final_upsample_channels = block_out_channels[0]

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = (
                reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
            )

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.out_block = get_out_block(
            out_block_type=out_block_type,
            num_groups_out=num_groups_out,
            embed_dim=block_out_channels[0],
            out_channels=out_channels,
            act_fn=act_fn,
            fc_dim=block_out_channels[-1] // 4,
        )

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[UNet1DOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): `(batch_size, num_channels, sample_size)` noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps. Python scalars and 0-d tensors
                are promoted to a 1-element tensor on `sample.device`; Python floats keep their fractional part.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
            otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # Casting a Python float to `long` would truncate it (e.g. 0.5 -> 0), so keep floats as floats.
            dtype = torch.float32 if isinstance(timesteps, float) else torch.long
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        else:
            if timesteps.ndim == 0:
                timesteps = timesteps[None]
            timesteps = timesteps.to(sample.device)

        timestep_embed = self.time_proj(timesteps)
        if self.config.use_timestep_embedding:
            # Match the sample dtype so the MLP weights (e.g. fp16) and its input agree.
            timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
        else:
            # Add a length axis to the projection (assumed (batch, channels) — TODO confirm), tile it to the sample
            # length, then broadcast it to the sample batch size.
            timestep_embed = timestep_embed[..., None]
            timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
            timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))

        # 2. down
        down_block_res_samples = ()
        for downsample_block in self.down_blocks:
            sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
            down_block_res_samples += res_samples

        # 3. mid
        if self.mid_block is not None:
            sample = self.mid_block(sample, timestep_embed)

        # 4. up: each up block consumes the most recent remaining skip connection.
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-1:]
            down_block_res_samples = down_block_res_samples[:-1]
            sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)

        # 5. post-process
        if self.out_block is not None:
            sample = self.out_block(sample, timestep_embed)

        if not return_dict:
            return (sample,)

        return UNet1DOutput(sample=sample)