# File: enclap/modeling/enclap_bart.py
# Committed by tonyswoo in 73baeae ("Initial Commit"); 22.5 kB.
import math
import random
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
BaseModelOutput,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from transformers.models.bart.configuration_bart import BartConfig
from transformers.models.bart.modeling_bart import (
BartDecoder,
BartEncoderLayer,
BartForConditionalGeneration,
BartLearnedPositionalEmbedding,
BartModel,
BartPretrainedModel,
_expand_mask,
shift_tokens_right,
)
from transformers.utils import logging
from .modeling_outputs import EnClapBartOutput
# Module-level logger from transformers' logging utilities; used below to warn when
# `use_cache` is overridden because `labels` were provided.
logger = logging.get_logger(__name__)
class EnClapBartConfig(BartConfig):
    """BART configuration extended with EnClap-specific hyper-parameters.

    Args:
        d_clap: Size of the incoming CLAP embedding; it is linearly projected to ``d_model``.
        num_rvq: Number of EnCodec code columns, i.e. how many EnCodec embedding tables and
            MCM heads the models in this file build.
        encodec_vocab_size: Number of EnCodec code ids; also the output size of each MCM head.
        encodec_pad_token_id: ``padding_idx`` used by the EnCodec embedding tables.
        mcm_loss_scale: Multiplier applied to the MCM loss when it is added to the LM loss.
        label_smoothing: Label smoothing passed to the LM cross-entropy loss.
        **kwargs: Forwarded unchanged to :class:`BartConfig`.
    """

    def __init__(
        self,
        d_clap: int = 512,
        num_rvq: int = 16,
        encodec_vocab_size: int = 1024,
        encodec_pad_token_id: int = 1024,
        mcm_loss_scale: float = 0.7,
        label_smoothing: float = 0.2,
        **kwargs,
    ):
        # Let BartConfig consume its own keyword arguments first.
        super().__init__(**kwargs)

        # Audio-side (CLAP / EnCodec) dimensions.
        self.d_clap = d_clap
        self.num_rvq = num_rvq
        self.encodec_vocab_size = encodec_vocab_size
        self.encodec_pad_token_id = encodec_pad_token_id

        # Training-objective knobs.
        self.label_smoothing = label_smoothing
        self.mcm_loss_scale = mcm_loss_scale
class EnClapBartEncoder(BartPretrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`BartEncoderLayer`].

    Besides plain 2-D BART token ids, the encoder accepts a 3-D ``input_ids`` tensor whose last
    dimension carries (at least) ``config.num_rvq`` EnCodec code ids per position. Positions
    flagged by ``encodec_mask`` are embedded as the sum of one embedding per code column. All
    other positions are embedded with ``embed_tokens`` applied to column 0. An optional CLAP
    embedding, projected to ``d_model``, overwrites position 0.

    Args:
        config: EnClapBartConfig
        embed_tokens (nn.Embedding): token embedding whose weight is tied to ``self.embed_tokens``
    """

    def __init__(
        self, config: EnClapBartConfig, embed_tokens: Optional[nn.Embedding] = None
    ):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        clap_dim = config.d_clap
        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
        if embed_tokens is not None:
            # Tie to the caller's (shared) embedding matrix.
            self.embed_tokens.weight = embed_tokens.weight

        # One embedding table per EnCodec code column. The table size is
        # encodec_vocab_size + 1 rounded up to a multiple of 64; the +1 presumably leaves room
        # for the pad id (default 1024 == encodec_vocab_size) -- confirm.
        self.embed_encodec = nn.ModuleList(
            [
                nn.Embedding(
                    math.ceil((config.encodec_vocab_size + 1) / 64) * 64,
                    config.d_model,
                    padding_idx=config.encodec_pad_token_id,
                )
                for _ in range(config.num_rvq)
            ]
        )
        # Maps a d_clap-sized CLAP embedding into the model dimension.
        self.clap_projection = nn.Linear(clap_dim, embed_dim)

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList(
            [BartEncoderLayer(config) for _ in range(config.encoder_layers)]
        )
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        clap_embedding: Optional[torch.Tensor] = None,
        encodec_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, n_codes)`):
                2-D: indices of input sequence tokens in the vocabulary. 3-D: mixed sequence. Where
                ``encodec_mask > 0``, the first ``config.num_rvq`` entries of the last dimension are
                EnCodec code ids. Elsewhere, entry 0 is a BART token id.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            clap_embedding (`torch.Tensor` of shape `(batch_size, d_clap)`, *optional*):
                Projected to ``d_model`` and written into position 0. Only used with 3-D ``input_ids``.
            encodec_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Required with 3-D ``input_ids``. Values > 0 mark EnCodec positions; 0 marks BART
                token positions.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Raises:
            ValueError: if both or neither of ``input_ids``/``inputs_embeds`` are given, if
                ``input_ids`` is not 2-D or 3-D, if ``encodec_mask`` is missing for 3-D
                ``input_ids``, or if ``head_mask`` has the wrong number of layers.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self._embed_input_ids(input_ids, clap_embedding, encodec_mask)
            # input_ids (2-D or 3-D) are handed to the positional embedding as-is.
            position_source = input_ids
        else:
            # Without ids, take a 2-D slice of the embeddings for the positional embedding
            # (same approach as upstream BART). Previously this path dereferenced input_ids=None.
            position_source = inputs_embeds[:, :, -1]

        batch_size = inputs_embeds.size(0)
        embed_pos = self.embed_positions(position_source).to(self.device)
        # Shift learned positions right by one: position 0 (the CLAP slot) gets a zero vector
        # and position k uses the embedding of position k-1. new_zeros keeps dtype/device.
        embed_pos = torch.cat(
            [
                embed_pos.new_zeros(batch_size, 1, self.config.d_model),
                embed_pos[:, :-1],
            ],
            dim=1,
        )

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (
                dropout_probability < self.layerdrop
            ):  # skip the layer
                layer_outputs = (None, None)
            else:
                layer_head_mask = head_mask[idx] if head_mask is not None else None
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        attention_mask,
                        layer_head_mask,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=layer_head_mask,
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )

    def _embed_input_ids(
        self,
        input_ids: torch.LongTensor,
        clap_embedding: Optional[torch.Tensor],
        encodec_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Embed 2-D BART ids, or 3-D mixed BART/EnCodec ids plus an optional CLAP vector."""
        if input_ids.ndim == 2:
            return self.embed_tokens(input_ids) * self.embed_scale
        if input_ids.ndim != 3:
            raise ValueError(
                f"`input_ids` must be 2-D or 3-D, got {input_ids.ndim} dimensions"
            )
        if encodec_mask is None:
            raise ValueError("`encodec_mask` is required when `input_ids` is 3-D")

        is_encodec = encodec_mask.unsqueeze(-1) > 0
        # Zero out code ids at BART positions so every lookup index is valid.
        encodec_ids = torch.where(is_encodec, input_ids, 0)
        # Accumulate in the embedding dtype/device so half-precision models stay in half
        # precision (a float32 accumulator would promote everything downstream).
        encodec_embeds = torch.zeros(
            input_ids.shape[0],
            input_ids.shape[1],
            self.config.d_model,
            dtype=self.embed_tokens.weight.dtype,
            device=self.embed_tokens.weight.device,
        )
        for i, embed in enumerate(self.embed_encodec):
            encodec_embeds = encodec_embeds + embed(encodec_ids[..., i])

        bart_ids = torch.where(encodec_mask == 0, input_ids[..., 0], 0)
        # NOTE: unlike the 2-D path, embed_scale is not applied here. This is kept as in the
        # original implementation because changing it would alter outputs of existing
        # weights whenever config.scale_embedding is set.
        bart_embeds = self.embed_tokens(bart_ids)
        inputs_embeds = torch.where(is_encodec, encodec_embeds, bart_embeds)

        if clap_embedding is not None:
            # Position 0 is overwritten by the projected CLAP embedding.
            inputs_embeds[:, 0] = self.clap_projection(clap_embedding)
        return inputs_embeds.to(self.device)
class EnClapBartModel(BartModel):
    """BART encoder-decoder whose encoder is :class:`EnClapBartEncoder`.

    Encoder and decoder are both constructed with ``self.shared`` as their token embedding.
    """

    def __init__(self, config: EnClapBartConfig):
        # Skip BartModel.__init__ and run the next initializer in the MRO, so the submodules
        # can be built here (presumably to avoid constructing BartModel's default encoder).
        super(BartModel, self).__init__(config)

        self.shared = nn.Embedding(
            config.vocab_size, config.d_model, config.pad_token_id
        )
        self.encoder = EnClapBartEncoder(config, self.shared)
        self.decoder = BartDecoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        clap_embedding: Optional[torch.Tensor] = None,
        encodec_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqModelOutput]:
        """Run the EnClap encoder (unless ``encoder_outputs`` is given) and the BART decoder.

        ``clap_embedding`` and ``encodec_mask`` are forwarded to :class:`EnClapBartEncoder`.
        The remaining arguments follow ``BartModel.forward``.

        Returns:
            A :class:`Seq2SeqModelOutput` when ``return_dict`` resolves to True. Otherwise, the
            decoder output tuple concatenated with the encoder output tuple.

        Raises:
            ValueError: if decoder inputs must be derived from ``input_ids`` but it is None.
        """
        # BART builds decoder inputs from input_ids when the caller supplies neither
        # decoder_input_ids nor decoder_inputs_embeds.
        must_derive_decoder_ids = (
            decoder_input_ids is None and decoder_inputs_embeds is None
        )
        if must_derive_decoder_ids:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Fall back to config defaults for any flag the caller left unset.
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = cfg.use_cache
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                clap_embedding=clap_embedding,
                encodec_mask=encodec_mask,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # A caller-supplied tuple is wrapped so attribute access works below.
            n_enc = len(encoder_outputs)
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if n_enc > 1 else None,
                attentions=encoder_outputs[2] if n_enc > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return Seq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )
        return decoder_outputs + encoder_outputs
class EnClapBartForConditionalGeneration(BartForConditionalGeneration):
    """BART conditional-generation model built on :class:`EnClapBartModel`.

    Besides the LM head, it owns ``config.num_rvq`` MCM heads. Each head maps the encoder's last
    hidden state to ``config.encodec_vocab_size`` logits, one head per EnCodec code column.
    """

    config_class = EnClapBartConfig

    def __init__(self, config: EnClapBartConfig):
        # Skip BartForConditionalGeneration.__init__ and run the next initializer in the MRO,
        # so the model can be assembled here around EnClapBartModel.
        super(BartForConditionalGeneration, self).__init__(config)
        self.model = EnClapBartModel(config)
        self.register_buffer(
            "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
        )
        self.lm_head = nn.Linear(
            config.d_model, self.model.shared.num_embeddings, bias=False
        )
        # One classification head per EnCodec code column.
        self.mcm_heads = nn.ModuleList(
            [
                nn.Linear(config.d_model, config.encodec_vocab_size)
                for _ in range(config.num_rvq)
            ]
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        clap_embedding: Optional[torch.Tensor] = None,
        encodec_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        mcm_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        mcm_labels (`torch.LongTensor`, *optional*):
            Targets for the MCM heads; head ``i`` is trained on ``mcm_labels[..., i]``. The shape is
            presumably `(batch_size, sequence_length, num_rvq)` -- TODO confirm. Entries equal to
            `CrossEntropyLoss`'s default `ignore_index` (-100) are ignored.

        Returns:
            With ``return_dict``: an ``EnClapBartOutput``. Its ``loss`` is the LM loss, the MCM loss,
            or ``lm_loss + mcm_loss_scale * mcm_loss`` when both label sets are given. Otherwise: a
            tuple ``(loss?, lm_logits, *model outputs after the decoder's last hidden state)``.
            ``loss`` is included only when it is not None.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if labels is not None:
            if use_cache:
                logger.warning(
                    "The `use_cache` argument is changed to `False` since `labels` is provided."
                )
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        # Always request a ModelOutput internally: the MCM loss needs
        # `encoder_last_hidden_state` by name, which a plain tuple cannot provide.
        # Tuple output is rebuilt from it below via `to_tuple()`.
        outputs = self.model(
            input_ids,
            clap_embedding=clap_embedding,
            encodec_mask=encodec_mask,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        mcm_loss = None
        if mcm_labels is not None:
            mcm_loss = 0.0
            mcm_loss_fct = CrossEntropyLoss()
            encoder_hidden = outputs.encoder_last_hidden_state
            for i, mcm_head in enumerate(self.mcm_heads):
                mcm_logits = mcm_head(encoder_hidden)
                # Head i is weighted by 2^-(i + 1).
                loss_scale = 1 / 2 ** (i + 1)
                head_loss = mcm_loss_fct(
                    mcm_logits.view(-1, self.config.encodec_vocab_size),
                    mcm_labels[..., i].reshape(-1).to(mcm_logits.device),
                )
                mcm_loss = mcm_loss + head_loss * loss_scale

        lm_logits = self.lm_head(outputs.last_hidden_state)
        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

        masked_lm_loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            lm_loss_fct = CrossEntropyLoss(label_smoothing=self.config.label_smoothing)
            masked_lm_loss = lm_loss_fct(
                lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if mcm_loss is None:
            loss = masked_lm_loss
        elif masked_lm_loss is None:
            loss = mcm_loss
        else:
            mcm_loss = mcm_loss * self.config.mcm_loss_scale
            loss = masked_lm_loss + mcm_loss

        if not return_dict:
            output = (lm_logits,) + outputs.to_tuple()[1:]
            # Use the combined loss (not only the LM loss) so tuple and dict outputs agree.
            return ((loss,) + output) if loss is not None else output

        return EnClapBartOutput(
            loss=loss,
            lm_loss=masked_lm_loss,
            mcm_loss=mcm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )