# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from functools import partial
import logging
import math
import typing as tp

import torch
from torch import nn
from torchdiffeq import odeint  # type: ignore

from ..modules.streaming import StreamingModule
from ..modules.transformer import create_norm_fn, StreamingTransformerLayer
from ..modules.unet_transformer import UnetTransformer
from ..modules.conditioners import (
    ConditionFuser,
    ClassifierFreeGuidanceDropout,
    AttributeDropout,
    ConditioningAttributes,
    JascoCondConst
)
from ..modules.jasco_conditioners import JascoConditioningProvider
from ..modules.activations import get_activation_fn

from .lm import ConditionTensors, init_layer

logger = logging.getLogger(__name__)


@dataclass
class FMOutput:
    latents: torch.Tensor  # [B, T, D]
    mask: torch.Tensor     # [B, T]


class CFGTerm:
    """
    Base class for Multi Source Classifier-Free Guidance (CFG) terms. A CFG term represents one source in the
    guidance process, used to steer generation by adjusting the influence of different conditions.

    Attributes:
        conditions (dict): A dictionary of conditions that influence the generation process.
        weight (float): The weight of the CFG term, determining its influence on the generation.
    """
    def __init__(self, conditions, weight):
        self.conditions = conditions
        self.weight = weight

    def drop_irrelevant_conds(self, conditions):
        """
        Drops irrelevant conditions from the CFG term. This method should be implemented by subclasses.

        Args:
            conditions (dict): The conditions to be filtered.

        Raises:
            NotImplementedError: If the method is not implemented in a subclass.
        """
        raise NotImplementedError("No base implementation for dropping irrelevant conditions.")


class AllCFGTerm(CFGTerm):
    """
    A CFG term that retains all conditions. This class does not drop any condition.
    """
    def __init__(self, conditions, weight):
        super().__init__(conditions, weight)
        self.drop_irrelevant_conds()

    def drop_irrelevant_conds(self):
        pass


class NullCFGTerm(CFGTerm):
    """
    A CFG term that drops all conditions, effectively nullifying their influence.
    """
    def __init__(self, conditions, weight):
        super().__init__(conditions, weight)
        self.drop_irrelevant_conds()

    def drop_irrelevant_conds(self):
        """
        Drops all conditions by applying a dropout with probability 1.0, effectively nullifying their influence.
        """
        self.conditions = ClassifierFreeGuidanceDropout(p=1.0)(
            samples=self.conditions,
            cond_types=["wav", "text", "symbolic"])
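
# NOTE: with p=1.0 the ClassifierFreeGuidanceDropout above zeroes out every
# condition type, so a NullCFGTerm yields the fully unconditional prediction
# that serves as the baseline term in the multi-source CFG combination below.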


class TextCFGTerm(CFGTerm):
    """
    A CFG term that selectively drops conditions based on specified dropout probabilities for different types
    of conditions, such as 'symbolic' and 'wav'.
    """
    def __init__(self, conditions, weight, model_att_dropout):
        """
        Initializes a TextCFGTerm with specified conditions, weight, and model attribute dropout configuration.

        Args:
            conditions (dict): The conditions to be used in the CFG process.
            weight (float): The weight of the CFG term.
            model_att_dropout (object): The attribute dropouts used by the model.
        """
        super().__init__(conditions, weight)
        if 'symbolic' in model_att_dropout.p:
            self.drop_symbolics = {k: 1.0 for k in model_att_dropout.p['symbolic'].keys()}
        else:
            self.drop_symbolics = {}
        if 'wav' in model_att_dropout.p:
            self.drop_wav = {k: 1.0 for k in model_att_dropout.p['wav'].keys()}
        else:
            self.drop_wav = {}
        self.drop_irrelevant_conds()

    def drop_irrelevant_conds(self):
        self.conditions = AttributeDropout({'symbolic': self.drop_symbolics,
                                            'wav': self.drop_wav})(self.conditions)  # drop temporal conds
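
# The three CFG terms defined above map onto the multi-source guidance used at
# inference time: AllCFGTerm keeps every condition (text + temporal), TextCFGTerm
# keeps only the text condition (symbolic/wav attributes are dropped with p=1.0),
# and NullCFGTerm keeps nothing. Their weighted vector fields are summed in
# FlowMatchingModel._multi_source_cfg_postprocess.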


class FlowMatchingModel(StreamingModule):
    """
    A flow matching model that inherits from StreamingModule.
    This model uses a transformer architecture to process and fuse conditions, applying learned embeddings and
    transformations, and predicts multi-source guided vector fields.

    Attributes:
        condition_provider (JascoConditioningProvider): Provider for conditioning attributes.
        fuser (ConditionFuser): Fuser for combining multiple conditions.
        dim (int): Dimensionality of the model's main features.
        num_heads (int): Number of attention heads in the transformer.
        flow_dim (int): Dimensionality of the flow features.
        chords_dim (int): Dimensionality for chord embeddings, if used.
        drums_dim (int): Dimensionality for drums embeddings, if used.
        melody_dim (int): Dimensionality for melody embeddings, if used.
        hidden_scale (int): Scaling factor for the dimensionality of the feedforward network in the transformer.
        norm (str): Type of normalization to use ('layer_norm' or other supported types).
        norm_first (bool): Whether to apply normalization before other operations in the transformer layers.
        bias_proj (bool): Whether to include bias in the projection layers.
        weight_init (Optional[str]): Method for initializing weights.
        depthwise_init (Optional[str]): Method for initializing depthwise convolutional layers.
        zero_bias_init (bool): Whether to initialize biases to zero.
        cfg_dropout (float): Dropout probability for classifier-free guidance.
        cfg_coef (float): Default classifier-free guidance coefficient.
        attribute_dropout (Dict[str, Dict[str, float]]): Dropout rates for specific attributes.
        time_embedding_dim (int): Dimensionality of time embeddings.
        **kwargs: Additional keyword arguments for the transformer.

    Methods:
        __init__: Initializes the model with the specified attributes and configuration.
    """
    def __init__(self, condition_provider: JascoConditioningProvider,
                 fuser: ConditionFuser,
                 dim: int = 128,
                 num_heads: int = 8,
                 flow_dim: int = 128,
                 chords_dim: int = 0,
                 drums_dim: int = 0,
                 melody_dim: int = 0,
                 hidden_scale: int = 4,
                 norm: str = 'layer_norm',
                 norm_first: bool = False,
                 bias_proj: bool = True,
                 weight_init: tp.Optional[str] = None,
                 depthwise_init: tp.Optional[str] = None,
                 zero_bias_init: bool = False,
                 cfg_dropout: float = 0,
                 cfg_coef: float = 1.0,
                 attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {},
                 time_embedding_dim: int = 128,
                 **kwargs):
        super().__init__()
        self.cfg_coef = cfg_coef
        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
        self.att_dropout = AttributeDropout(p=attribute_dropout)
        self.condition_provider = condition_provider
        self.fuser = fuser
        self.dim = dim  # transformer dim
        self.flow_dim = flow_dim
        self.chords_dim = chords_dim
        self.emb = nn.Linear(flow_dim + chords_dim + drums_dim + melody_dim, dim, bias=False)
        if 'activation' in kwargs:
            kwargs['activation'] = get_activation_fn(kwargs['activation'])
        self.transformer = UnetTransformer(
            d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
            norm=norm, norm_first=norm_first,
            layer_class=StreamingTransformerLayer,
            **kwargs)
        self.out_norm: tp.Optional[nn.Module] = None
        if norm_first:
            self.out_norm = create_norm_fn(norm, dim)
        self.linear = nn.Linear(dim, flow_dim, bias=bias_proj)
        self._init_weights(weight_init, depthwise_init, zero_bias_init)
        self._fsdp: tp.Optional[nn.Module]
        self.__dict__['_fsdp'] = None

        # init time parameter embedding
        self.d_temb1 = time_embedding_dim
        self.d_temb2 = 4 * time_embedding_dim
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList([
            torch.nn.Linear(self.d_temb1, self.d_temb2),
            torch.nn.Linear(self.d_temb2, self.d_temb2),
        ])
        self.temb_proj = nn.Linear(self.d_temb2, dim)

    def _get_timestep_embedding(self, timesteps, embedding_dim):
        """
        #######################################################################################################
        TAKEN FROM: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py
        #######################################################################################################
        This matches the implementation in Denoising Diffusion Probabilistic Models:
        From Fairseq.
        Build sinusoidal embeddings.
        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        assert len(timesteps.shape) == 1

        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
        emb = emb.to(device=timesteps.device)
        emb = timesteps.float()[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
        return emb
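
    # Shape walkthrough (a sketch, not exercised directly by the model code):
    # timesteps of shape [B] with embedding_dim=128 yield frequencies of shape [64],
    # an outer product of shape [B, 64], and after the sin/cos concat an embedding
    # of shape [B, 128], i.e. emb[b] = [sin(t_b * w_0), ..., cos(t_b * w_63)].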

    def _embed_time_parameter(self, t: torch.Tensor):
        """
        #######################################################################################################
        TAKEN FROM: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py
        #######################################################################################################
        """
        temb = self._get_timestep_embedding(t.flatten(), self.d_temb1)
        temb = self.temb.dense[0](temb)
        temb = temb * torch.sigmoid(temb)  # swish activation
        temb = self.temb.dense[1](temb)
        return temb
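
    # The two Linear layers with the swish non-linearity form the DDPM-style
    # timestep MLP referenced above; the resulting embedding is later projected by
    # self.temb_proj and added to the cross-attention input in forward(), which is
    # how the model is told where along the 0 -> 1 flow it currently is.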

    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
        """Initialization of the transformer module weights.

        Args:
            weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
            depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
                'current', where the depth corresponds to the current layer index, or 'global', where the total
                number of layers is used as depth. If not set, no depthwise initialization strategy is used.
            zero_bias_init (bool): Whether to initialize bias to zero or not.
        """
        assert depthwise_init is None or depthwise_init in ['current', 'global']
        assert depthwise_init is None or weight_init is not None, \
            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
        assert not zero_bias_init or weight_init is not None, \
            "If 'zero_bias_init', a 'weight_init' method should be provided"
        if weight_init is None:
            return

        init_layer(self.emb, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
        for layer_idx, tr_layer in enumerate(self.transformer.layers):
            depth = None
            if depthwise_init == 'current':
                depth = layer_idx + 1
            elif depthwise_init == 'global':
                depth = len(self.transformer.layers)
            init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
            tr_layer.apply(init_fn)
        init_layer(self.linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

    def _align_seq_length(self,
                          cond: torch.Tensor,
                          seq_len: int = 500):
        # trim if needed
        cond = cond[:, :seq_len, :]

        # pad if needed
        B, T, C = cond.shape
        if T < seq_len:
            cond = torch.cat((cond, torch.zeros((B, seq_len - T, C), dtype=cond.dtype, device=cond.device)), dim=1)

        return cond
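
    # For example, a condition of shape [B, 700, C] with seq_len=500 is trimmed to
    # [B, 500, C], while one of shape [B, 300, C] is zero-padded on the time axis
    # to [B, 500, C], so temporal conditions always match the latents' length.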

    def forward(self,
                latents: torch.Tensor,
                t: torch.Tensor,
                conditions: tp.List[ConditioningAttributes],
                condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
        """Apply a flow matching forward pass on latents and conditions.
        Given a tensor of noisy latents of shape [B, T, D], with D the flow dim and T the sequence steps,
        and a time parameter tensor t, return the vector field with shape [B, T, D].

        Args:
            latents (torch.Tensor): noisy latents.
            t (torch.Tensor): time parameter.
            conditions (list of ConditioningAttributes): Conditions to use when modeling
                the given latents. Note that when evaluating multiple times with the same conditioning
                you should pre-compute those and pass them as `condition_tensors`.
            condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
                tensors, see `conditions`.
        Returns:
            torch.Tensor: estimated vector field v_theta.
        """
        assert condition_tensors is not None, "FlowMatchingModel requires pre-calculation of condition tensors"
        assert not conditions, "Shouldn't pass unprocessed conditions to FlowMatchingModel."

        B, T, D = latents.shape
        x = latents

        # concat temporal conditions on the feature dimension
        temporal_conds = JascoCondConst.ALL.value
        for cond in temporal_conds:
            if cond not in condition_tensors:
                continue
            c = self._align_seq_length(condition_tensors[cond][0], seq_len=T)
            x = torch.concat((x, c), dim=-1)

        # project to transformer dimension
        input_ = self.emb(x)
        input_, cross_attention_input = self.fuser(input_, condition_tensors)

        # embed time parameter
        t_embs = self._embed_time_parameter(t)

        # add it to cross_attention_input
        cross_attention_input = cross_attention_input + self.temb_proj(t_embs[:, None, :])

        out = self.transformer(input_, cross_attention_src=cross_attention_input)
        if self.out_norm:
            out = self.out_norm(out)
        v_theta = self.linear(out)  # [B, T, D]

        # remove the prefix from the model outputs (the prefix is prepended on the time axis)
        if len(self.fuser.fuse2cond['prepend']) > 0:
            v_theta = v_theta[:, -T:]

        return v_theta  # [B, T, D]
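
    # Note that the feature-dim concatenation above is why self.emb was built with
    # input size flow_dim + chords_dim + drums_dim + melody_dim: each temporal
    # condition present in condition_tensors contributes its channels before the
    # projection to the transformer dimension.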

    def _multi_source_cfg_preprocess(self,
                                     conditions: tp.List[ConditioningAttributes],
                                     cfg_coef_all: float,
                                     cfg_coef_txt: float,
                                     min_weight: float = 1e-6):
        """
        Preprocesses the CFG terms for multi-source conditional generation.

        Args:
            conditions (list): A list of conditions to be applied.
            cfg_coef_all (float): The coefficient for the fully conditional term.
            cfg_coef_txt (float): The coefficient for the text-only term.
            min_weight (float): The minimal absolute weight for keeping a CFG term.

        Returns:
            tuple: A tuple containing condition_tensors and cfg_terms.
                condition_tensors is a dictionary or ConditionTensors object with tokenized conditions.
                cfg_terms is a list of CFGTerm objects with weights adjusted based on the coefficients.
        """
        condition_tensors: tp.Optional[ConditionTensors]
        cfg_terms = []
        if conditions:
            # conditional terms
            cfg_terms = [AllCFGTerm(conditions=conditions, weight=cfg_coef_all),
                         TextCFGTerm(conditions=conditions, weight=cfg_coef_txt,
                                     model_att_dropout=self.att_dropout)]

            # add null term
            cfg_terms.append(NullCFGTerm(conditions=conditions, weight=1 - sum([ct.weight for ct in cfg_terms])))

            # remove terms with negligible weight (rebuild the list rather than
            # removing items while iterating, which would skip elements)
            cfg_terms = [ct for ct in cfg_terms if abs(ct.weight) >= min_weight]

            conds: tp.List[ConditioningAttributes] = sum([ct.conditions for ct in cfg_terms], [])
            tokenized = self.condition_provider.tokenize(conds)
            condition_tensors = self.condition_provider(tokenized)
        else:
            condition_tensors = {}

        return condition_tensors, cfg_terms
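
    # The null term's weight is chosen as 1 - (cfg_coef_all + cfg_coef_txt), so the
    # three term weights always sum to 1; an unguided setup (cfg_coef_all=1,
    # cfg_coef_txt=0) degenerates to the plain conditional prediction, since the
    # text and null terms then get weight 0 and are pruned by the min_weight filter.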

    def estimated_vector_field(self, z, t, condition_tensors=None, cfg_terms=[]):
        """
        Estimates the vector field for the given latent variables and time parameter,
        conditioned on the provided conditions.

        Args:
            z (torch.Tensor): The noisy latent variables.
            t (torch.Tensor): The time parameter.
            condition_tensors (ConditionTensors, optional): The condition tensors. Defaults to None.
            cfg_terms (list, optional): The list of CFG terms. Defaults to an empty list.

        Returns:
            torch.Tensor: The estimated vector field.
        """
        if len(cfg_terms) > 1:
            z = z.repeat(len(cfg_terms), 1, 1)  # duplicate noisy latents for multi-source CFG
        v_thetas = self(latents=z, t=t, conditions=[], condition_tensors=condition_tensors)
        return self._multi_source_cfg_postprocess(v_thetas, cfg_terms)
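
    # Tiling z len(cfg_terms) times lets a single batched forward pass evaluate
    # every CFG branch at once; _multi_source_cfg_postprocess then chunks the
    # result back into per-term vector fields before the weighted sum. The chunk
    # order matches the order in which _multi_source_cfg_preprocess concatenated
    # the per-term conditions, keeping latents and conditions aligned.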

    def _multi_source_cfg_postprocess(self, v_thetas, cfg_terms):
        """
        Postprocesses the vector fields generated for each CFG term to combine them into a single vector field.
        Multi-source guidance occurs here.

        Args:
            v_thetas (torch.Tensor): The batched vector fields, one chunk per CFG term.
            cfg_terms (list): The CFG terms used.

        Returns:
            torch.Tensor: The combined vector field.
        """
        if len(cfg_terms) <= 1:
            return v_thetas
        v_theta_per_term = v_thetas.chunk(len(cfg_terms))
        return sum([ct.weight * term_vf for ct, term_vf in zip(cfg_terms, v_theta_per_term)])
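
    # With the default coefficients used by generate() (cfg_coef_all=3.0,
    # cfg_coef_txt=1.0), the null term gets weight 1 - 3.0 - 1.0 = -3.0 and the
    # combined field is v = 3.0 * v_all + 1.0 * v_txt - 3.0 * v_null, i.e. both
    # conditional predictions are amplified while the unconditional one is
    # subtracted.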

    def generate(self,
                 prompt: tp.Optional[torch.Tensor] = None,
                 conditions: tp.List[ConditioningAttributes] = [],
                 num_samples: tp.Optional[int] = None,
                 max_gen_len: int = 256,
                 callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                 cfg_coef_all: float = 3.0,
                 cfg_coef_txt: float = 1.0,
                 euler: bool = False,
                 euler_steps: int = 100,
                 ode_rtol: float = 1e-5,
                 ode_atol: float = 1e-5,
                 ) -> torch.Tensor:
        """
        Generate audio latents given a prompt or unconditionally. This method supports both Euler integration
        and adaptive ODE solving to generate sequences based on the specified conditions and configuration
        coefficients.

        Args:
            prompt (torch.Tensor, optional): Initial prompt to condition the generation. Defaults to None.
            conditions (List[ConditioningAttributes]): List of conditioning attributes - text, symbolic or audio.
            num_samples (int, optional): Number of samples to generate.
                If None, it is inferred from the number of conditions.
            max_gen_len (int): Maximum length of the generated sequence.
            callback (Callable[[int, int], None], optional): Callback function to monitor the generation process.
            cfg_coef_all (float): Coefficient for the fully conditional CFG term.
            cfg_coef_txt (float): Coefficient for the text CFG term.
            euler (bool): If True, use Euler integration, otherwise use an adaptive ODE solver.
            euler_steps (int): Number of Euler steps to perform if Euler integration is used.
            ode_rtol (float): ODE solver rtol threshold.
            ode_atol (float): ODE solver atol threshold.

        Returns:
            torch.Tensor: Generated latents, shaped as (num_samples, max_gen_len, feature_dim).
        """
        assert not self.training, "generation shouldn't be used in training mode."
        first_param = next(iter(self.parameters()))
        device = first_param.device

        # Checking all input shapes are consistent.
        possible_num_samples = []
        if num_samples is not None:
            possible_num_samples.append(num_samples)
        elif prompt is not None:
            possible_num_samples.append(prompt.shape[0])
        elif conditions:
            possible_num_samples.append(len(conditions))
        else:
            possible_num_samples.append(1)
        assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent inputs shapes"
        num_samples = possible_num_samples[0]

        condition_tensors, cfg_terms = self._multi_source_cfg_preprocess(conditions, cfg_coef_all, cfg_coef_txt)

        # flow matching inference
        B, T, D = num_samples, max_gen_len, self.flow_dim
        z_0 = torch.randn((B, T, D), device=device)

        if euler:
            # vanilla Euler integration
            dt = (1 / euler_steps)
            z = z_0
            t = torch.zeros((1, ), device=device)
            for _ in range(euler_steps):
                v_theta = self.estimated_vector_field(z, t,
                                                      condition_tensors=condition_tensors,
                                                      cfg_terms=cfg_terms)
                z = z + dt * v_theta
                t = t + dt
            z_1 = z
        else:
            # solve with dynamic ode integrator (dopri5)
            t = torch.tensor([0, 1.0 - 1e-5], device=device)
            num_evals = 0

            # define ode vector field function
            def inner_ode_func(t, z):
                nonlocal num_evals
                num_evals += 1
                if callback is not None:
                    ESTIMATED_ODE_SOLVER_STEPS = 300
                    callback(num_evals, ESTIMATED_ODE_SOLVER_STEPS)
                return self.estimated_vector_field(z, t,
                                                   condition_tensors=condition_tensors,
                                                   cfg_terms=cfg_terms)

            ode_opts: dict = {"options": {}}
            z = odeint(
                inner_ode_func,
                z_0,
                t,
                **{"atol": ode_atol, "rtol": ode_rtol, **ode_opts},
            )
            logger.info("Generated in %d steps", num_evals)
            z_1 = z[-1]

        return z_1
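
# A minimal usage sketch (assumes `provider` and `fuser` are pre-built
# JascoConditioningProvider / ConditionFuser instances and `attrs` is a list of
# ConditioningAttributes; none of these are defined in this file):
#
#     model = FlowMatchingModel(condition_provider=provider, fuser=fuser).eval()
#     with torch.no_grad():
#         latents = model.generate(conditions=attrs, max_gen_len=500,
#                                  euler=True, euler_steps=10)
#     # latents: [len(attrs), 500, model.flow_dim]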