Pathumma-llm-audio-1.0.0 / modeling_pathumma_audio.py
PATTARA TIPAKSORN
Upload 9 files
dae6ad4 verified
"""
This project adapts code from two existing open-source projects:
SALMONN by ByteDance: https://github.com/bytedance/SALMONN
Original project for multimodal language modeling including audio.
Llama 3 Typhoon Audio Preview by SCB 10X: https://huggingface.co/scb10x/llama-3-typhoon-v1.5-8b-audio-preview
An adaptation of SALMONN for audio processing.
Modifications and additional features have been implemented on top of these foundational works.
"""
import os
import gc
import math
import logging
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PreTrainedModel,
Qwen2Config,
Qwen2Tokenizer,
Qwen2ForCausalLM,
WhisperConfig,
WhisperFeatureExtractor,
WhisperModel,
)
from peft import (
LoraConfig,
TaskType,
get_peft_model
)
import torchaudio.compliance.kaldi as ta_kaldi
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput
)
from transformers.modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from transformers.models.bert.configuration_bert import BertConfig
from typing import Any, Dict, List, Optional, Tuple
from .configuration_pathumma_audio import PathummaAudioConfig
logger = logging.getLogger(__name__)
class PathummaAudioModel(PreTrainedModel):
config_class = PathummaAudioConfig
def __init__(self, config):
super().__init__(config)
if isinstance(config.torch_dtype, str):
self.torch_dtype = getattr(torch, config.torch_dtype, torch.bfloat16)
else:
self.torch_dtype = config.torch_dtype
# Qwen2 tokenizer
self.qwen2_tokenizer = Qwen2Tokenizer.from_pretrained(
config.llm_path,
use_fast=False,
)
# self.qwen2_tokenizer.add_special_tokens({"pad_token": "<|endoftext|>")
# self.qwen2_tokenizer.padding_side = "right"
# Qwen2 Model
if config.init_from_scratch:
qwen2_config = Qwen2Config.from_pretrained(config.llm_path, torch_dtype=self.torch_dtype)
self.qwen2_model = Qwen2ForCausalLM(qwen2_config).to(self.torch_dtype)
else:
self.qwen2_model = Qwen2ForCausalLM.from_pretrained(
config.llm_path,
torch_dtype=self.torch_dtype
)
# Load and configure lora adapter
if config.lora:
self.peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=config.lora_infer_mode,
r=config.lora_rank,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
target_modules=config.target_modules,
)
self.qwen2_model = get_peft_model(self.qwen2_model, self.peft_config)
# self.qwen2_model.print_trainable_parameters()
# Whisper feature extractor
self.feature_extractor = WhisperFeatureExtractor.from_pretrained(config.whisper_path)
# Whisper encoder model
if config.init_from_scratch:
whisper_config = WhisperConfig.from_pretrained(
config.whisper_path,
torch_dtype=self.torch_dtype
)
self.whisper_encoder = WhisperModel(whisper_config).encoder.to(self.torch_dtype)
else:
self.whisper_encoder = WhisperModel.from_pretrained(
config.whisper_path,
torch_dtype=self.torch_dtype,
).encoder
self.ln_speech = nn.LayerNorm(self.whisper_encoder.config.d_model, dtype=self.torch_dtype)
# Beats model
if config.init_from_scratch:
beats_cfg = BEATsConfig()
self.beats = BEATs(beats_cfg)
else:
beats_ckpt = torch.load(config.beats_path, map_location='cpu')
beats_cfg = BEATsConfig(beats_ckpt['cfg'])
self.beats = BEATs(beats_cfg)
self.beats.load_state_dict(beats_ckpt['model'])
self.beats.to(self.torch_dtype)
self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim, dtype=self.torch_dtype)
# Q-former model
self.second_per_window = config.second_per_window
self.second_stride = config.second_stride
self.window_level_qformer, self.query_tokens = self.init_window_level_qformer(
num_query_token = config.qformer_query_token,
num_hidden_layers = config.qformer_hidden_layers,
encoder_width = self.whisper_encoder.config.d_model + self.beats.cfg.encoder_embed_dim,
)
self.window_level_qformer.bert.embeddings.word_embeddings = None
self.window_level_qformer.bert.embeddings.position_embeddings = None
for layer in self.window_level_qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
self.window_level_qformer.cls = None
self.qformer_qwen2_proj = nn.Linear(
self.window_level_qformer.config.hidden_size,
self.qwen2_model.config.hidden_size,
dtype=self.torch_dtype
)
def init_window_level_qformer(self, num_query_token, num_hidden_layers, encoder_width):
encoder_config = BertConfig()
encoder_config.num_hidden_layers = num_hidden_layers
encoder_config.encoder_width = encoder_width
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = 1
encoder_config.query_length = num_query_token
qformer = BertLMHeadModel(config=encoder_config).to(self.torch_dtype)
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size, dtype=self.torch_dtype)
)
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
return qformer, query_tokens
def encode_auditory_features(self, spectrogram, raw_wave):
# whisper
spectrogram = spectrogram.to(self.torch_dtype)
speech_embeds = self.whisper_encoder(spectrogram, return_dict=True).last_hidden_state
speech_embeds = self.ln_speech(speech_embeds)
# beats
padding_mask = torch.zeros(raw_wave.shape).to(raw_wave.device).bool()
audio_embeds, _ = self.beats.extract_features(raw_wave, padding_mask=padding_mask, feature_only=True, torch_dtype=self.torch_dtype)
audio_embeds = self.ln_audio(audio_embeds)
# auditory embeds
if audio_embeds.size(1) < speech_embeds.size(1):
audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
elif audio_embeds.size(1) > speech_embeds.size(1):
speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
speech_audio_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
# q-former
B, T, C = speech_audio_embeds.shape
kernel = round(T * self.second_per_window / 30.0)
stride = round(T * self.second_stride / 30.0)
kernel = (1, kernel)
stride = (1, stride)
speech_audio_embeds_tr = speech_audio_embeds.transpose(1, 2).unsqueeze(2)
speech_audio_embeds_overlap = F.unfold(speech_audio_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride)
_, _, L = speech_audio_embeds_overlap.shape
speech_audio_embeds_overlap = speech_audio_embeds_overlap.view(B, -1, kernel[1], L)
speech_audio_embeds_overlap = torch.permute(speech_audio_embeds_overlap, [0, 3, 2, 1])
speech_audio_embeds = speech_audio_embeds_overlap.reshape(-1, kernel[1], C)
speech_audio_atts = torch.ones(speech_audio_embeds.size()[:-1], dtype=torch.long).to(speech_audio_embeds.device)
query_tokens = self.query_tokens.expand(speech_audio_embeds.shape[0], -1, -1)
query_output = self.window_level_qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=speech_audio_embeds,
encoder_attention_mask=speech_audio_atts,
return_dict=True
)
auditory_embeds = self.qformer_qwen2_proj(query_output.last_hidden_state)
auditory_embeds = auditory_embeds.view(B, -1, auditory_embeds.size(2)).contiguous()
auditory_atts = torch.ones(auditory_embeds.size()[:-1], dtype=torch.long).to(auditory_embeds.device)
return auditory_embeds, auditory_atts
def prompt_wrap(self, auditory_embeds, auditory_atts, prompts):
# prompt_template = "<|im_start|>user\n<Speech><SpeechHere></Speech> {}<|im_end|>\n<|im_start|>assistant\n"
prompt_template = "<|im_start|>system\nYou are assistant<|im_end|>\n<|im_start|>user\n<Speech><SpeechHere></Speech> {}<|im_end|>\n<|im_start|>assistant\n"
if not isinstance(prompts, List):
prompts = [prompts]
prompt_before = []
prompt_after = []
for prompt in prompts:
before, after = prompt_template.format(prompt).split("<SpeechHere>")
prompt_before.append(before)
prompt_after.append(after)
# Prompt before <SpeechHere>
prompt_before_tokens = self.qwen2_tokenizer(
prompt_before, return_tensors="pt", add_special_tokens=False
).to(auditory_embeds.device)
prompt_before_embeds = self.qwen2_model.model.model.embed_tokens(prompt_before_tokens.input_ids)
# Prompt prompt <SpeechHere>
prompt_after_tokens = self.qwen2_tokenizer(
prompt_after, return_tensors="pt", padding="longest", add_special_tokens=False
).to(auditory_embeds.device)
prompt_after_embeds = self.qwen2_model.model.model.embed_tokens(prompt_after_tokens.input_ids)
wrapped_embeds = torch.cat([prompt_before_embeds, auditory_embeds, prompt_after_embeds], dim=1)
wrapped_atts = torch.cat([prompt_before_tokens.attention_mask, auditory_atts, prompt_after_tokens.attention_mask], dim=1)
return wrapped_embeds, wrapped_atts
def forward(
self,
raw_wave,
spectrogram,
prompt,
completion,
**kwargs
):
auditory_embeds, auditory_atts = self.encode_auditory_features(spectrogram, raw_wave)
input_embeds, input_atts = self.prompt_wrap(auditory_embeds, auditory_atts, prompt)
end_sym = self.qwen2_tokenizer.eos_token
completion = [c + end_sym for c in completion]
next_tokens = self.qwen2_tokenizer(
completion,
return_tensors="pt",
padding="longest",
truncation=True,
max_length=512,
add_special_tokens=False
).to(spectrogram.device)
next_embeds = self.qwen2_model.model.model.embed_tokens(next_tokens.input_ids)
labels = next_tokens.input_ids.masked_fill(
next_tokens.input_ids == self.qwen2_tokenizer.pad_token_id, -100
)
empty_labels = torch.full(
(input_atts.shape[0], input_atts.shape[1]),
-100,
dtype=torch.long,
device=spectrogram.device
)
labels = torch.cat([empty_labels, labels], dim=1)
embeds = torch.cat([input_embeds, next_embeds], dim=1)
atts = torch.cat([input_atts, next_tokens.attention_mask], dim=1)
# _, _, C = logits.shape
# # Shift token < n will predict n
# shift_preds = logits[:, empty_labels.size(1) - 1: -1, :].contiguous().view(-1, C).argmax(dim=-1)
# shift_labels = labels[:, empty_labels.size(1):].contiguous().view(-1)
# mask = shift_labels != -100
# correct = (shift_preds[mask] == shift_labels[mask]).sum().float()
# total = mask.sum()
return self.qwen2_model(
inputs_embeds=embeds,
attention_mask=atts,
return_dict=True,
labels=labels,
**kwargs
)
def generate(
self,
raw_wave,
prompts,
device,
**kwargs
):
if isinstance(raw_wave, torch.Tensor):
raw_wave = raw_wave.cpu().numpy()
if raw_wave.ndim == 1:
raw_wave = np.expand_dims(raw_wave, axis=0)
spectrogram = self.feature_extractor(raw_wave, sampling_rate=16000, return_tensors="pt").input_features.to(device)
raw_wave = torch.from_numpy(raw_wave).to(device)
auditory_embeds, auditory_atts = self.encode_auditory_features(spectrogram, raw_wave)
embeds, atts = self.prompt_wrap(auditory_embeds, auditory_atts, prompts)
outputs = self.qwen2_model.generate(
inputs_embeds=embeds,
attention_mask=atts,
bos_token_id=self.qwen2_tokenizer.bos_token_id,
eos_token_id=self.qwen2_tokenizer.eos_token_id,
pad_token_id=self.qwen2_tokenizer.pad_token_id,
**kwargs
)
output_texts = self.qwen2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
return output_texts
# BEATs
class BEATsConfig:
def __init__(self, cfg=None):
self.input_patch_size: int = 16 # path size of patch embedding
self.embed_dim: int = 512 # patch embedding dimension
self.conv_bias: bool = False # include bias in conv encoder
self.encoder_layers: int = 12 # num encoder layers in the transformer
self.encoder_embed_dim: int = 768 # encoder embedding dimension
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
self.encoder_attention_heads: int = 12 # num encoder attention heads
self.activation_fn: str = "gelu" # activation function to use
self.layer_wise_gradient_decay_ratio: float = 0.6 # ratio for layer-wise gradient decay
self.layer_norm_first: bool = False # apply layernorm first in the transformer
self.deep_norm: bool = True # apply deep_norm first in the transformer
# dropouts
self.dropout: float = 0.0 # dropout probability for the transformer
self.attention_dropout: float = 0.0 # dropout probability for attention weights
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
self.encoder_layerdrop: float = 0.05 # probability of dropping a tarnsformer layer
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
# positional embeddings
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
# relative position embedding
self.relative_position_embedding: bool = True # apply relative position embedding
self.num_buckets: int = 320 # number of buckets for relative position embedding
self.max_distance: int = 800 # maximum distance for relative position embedding
self.gru_rel_pos: bool = True # apply gated relative position embedding
# label predictor
self.finetuned_model: bool = True # whether the model is a fine-tuned model.
self.predictor_dropout: float = 0.0 # dropout probability for the predictor
self.predictor_class: int = 527 # target class number for the predictor
if cfg is not None:
self.update(cfg)
def update(self, cfg: dict):
self.__dict__.update(cfg)
class BEATs(nn.Module):
def __init__(
self,
cfg: BEATsConfig,
) -> None:
super().__init__()
logger.info(f"BEATs Config: {cfg.__dict__}")
self.cfg = cfg
self.embed = cfg.embed_dim
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim
else None
)
self.input_patch_size = cfg.input_patch_size
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
bias=cfg.conv_bias)
self.dropout_input = nn.Dropout(cfg.dropout_input)
assert not cfg.deep_norm or not cfg.layer_norm_first
self.encoder = TransformerEncoder(cfg)
self.layer_norm = nn.LayerNorm(self.embed)
if cfg.finetuned_model:
self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
else:
self.predictor = None
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(
padding_mask.size(0), features.size(1), -1
)
padding_mask = padding_mask.all(-1)
return padding_mask
def preprocess(
self,
source: torch.Tensor,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
) -> torch.Tensor:
fbanks = []
for waveform in source:
waveform = waveform.unsqueeze(0) * 2 ** 15
fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) ## problem
fbanks.append(fbank)
fbank = torch.stack(fbanks, dim=0)
fbank = (fbank - fbank_mean) / (2 * fbank_std)
return fbank
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
feature_only=False,
torch_dtype=torch.bfloat16,
):
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch_dtype)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(fbank, padding_mask)
fbank = fbank.unsqueeze(1)
features = self.patch_embedding(fbank)
features = features.reshape(features.shape[0], features.shape[1], -1)
features = features.transpose(1, 2)
features = self.layer_norm(features)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
x = self.dropout_input(features)
x, layer_results = self.encoder(
x,
padding_mask=padding_mask,
)
if not feature_only and self.predictor is not None:
x = self.predictor_dropout(x)
logits = self.predictor(x)
if padding_mask is not None and padding_mask.any():
logits[padding_mask] = 0
logits = logits.sum(dim=1)
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
else:
logits = logits.mean(dim=1)
lprobs = torch.sigmoid(logits)
return lprobs, padding_mask
else:
return x, padding_mask
class TransformerEncoder(nn.Module):
def __init__(self, args):
super().__init__()
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.pos_conv = nn.Conv1d(
self.embedding_dim,
self.embedding_dim,
kernel_size=args.conv_pos,
padding=args.conv_pos // 2,
groups=args.conv_pos_groups,
)
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.parametrizations.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
if hasattr(args, "relative_position_embedding"):
self.relative_position_embedding = args.relative_position_embedding
self.num_buckets = args.num_buckets
self.max_distance = args.max_distance
else:
self.relative_position_embedding = False
self.num_buckets = 0
self.max_distance = 0
self.layers = nn.ModuleList(
[
TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
deep_norm=args.deep_norm,
has_relative_attention_bias=self.relative_position_embedding,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
gru_rel_pos=args.gru_rel_pos,
encoder_layers=args.encoder_layers,
)
for i in range(args.encoder_layers)
]
)
if self.relative_position_embedding:
for i in range(1, args.encoder_layers):
del self.layers[i].self_attn.relative_attention_bias
self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
self.layer_norm_first = args.layer_norm_first
self.layer_norm = nn.LayerNorm(self.embedding_dim)
self.layerdrop = args.encoder_layerdrop
self.apply(init_bert_params)
if args.deep_norm:
deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
for i in range(args.encoder_layers):
nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
def forward(self, x, padding_mask=None, layer=None):
x, layer_results = self.extract_features(x, padding_mask, layer)
if self.layer_norm_first and layer is None:
x = self.layer_norm(x)
return x, layer_results
def extract_features(self, x, padding_mask=None, tgt_layer=None):
if padding_mask is not None:
x[padding_mask] = 0
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x = x + x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
layer_results = []
z = None
if tgt_layer is not None:
layer_results.append((x, z))
r = None
pos_bias = None
for i, layer in enumerate(self.layers):
if self.layer_wise_gradient_decay_ratio != 1.0:
x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
dropout_probability = np.random.random()
if not self.training or (dropout_probability > self.layerdrop):
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
if tgt_layer is not None:
layer_results.append((x, z))
if i == tgt_layer:
r = x
break
if r is not None:
x = r
# T x B x C -> B x T x C
x = x.transpose(0, 1)
return x, layer_results
class TransformerSentenceEncoderLayer(nn.Module):
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: float = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
deep_norm: bool = False,
has_relative_attention_bias: bool = False,
num_buckets: int = 0,
max_distance: int = 0,
rescale_init: bool = False,
gru_rel_pos: bool = False,
encoder_layers: int = 0,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
self.activation_name = activation_fn
self.activation_fn = get_activation_fn(activation_fn)
self.self_attn = MultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
has_relative_attention_bias=has_relative_attention_bias,
num_buckets=num_buckets,
max_distance=max_distance,
rescale_init=rescale_init,
gru_rel_pos=gru_rel_pos,
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(self.activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
if self.activation_name == "glu":
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
else:
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
self.deep_norm = deep_norm
if self.deep_norm:
self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
else:
self.deep_norm_alpha = 1
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
pos_bias=None
):
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x)
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
attn_mask=self_attn_mask,
position_bias=pos_bias
)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
else:
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=need_weights,
attn_mask=self_attn_mask,
position_bias=pos_bias
)
x = self.dropout1(x)
x = residual * self.deep_norm_alpha + x
x = self.self_attn_layer_norm(x)
residual = x
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual * self.deep_norm_alpha + x
x = self.final_layer_norm(x)
return x, attn, pos_bias
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
has_relative_attention_bias=False,
num_buckets=32,
max_distance=128,
gru_rel_pos=False,
rescale_init=False,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout_module = nn.Dropout(dropout)
self.has_relative_attention_bias = has_relative_attention_bias
self.num_buckets = num_buckets
self.max_distance = max_distance
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
self.head_dim = embed_dim // num_heads
self.q_head_dim = self.head_dim
self.k_head_dim = self.head_dim
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
)
k_bias = True
if rescale_init:
k_bias = False
k_embed_dim = embed_dim
q_embed_dim = embed_dim
self.k_proj = quant_noise(
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
)
self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
if add_bias_kv:
self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.gru_rel_pos = gru_rel_pos
if self.gru_rel_pos:
self.grep_linear = nn.Linear(self.q_head_dim, 8)
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
self.reset_parameters()
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
if self.has_relative_attention_bias:
nn.init.xavier_normal_(self.relative_attention_bias.weight)
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
num_buckets = self.num_buckets
max_distance = self.max_distance
relative_buckets = 0
if bidirectional:
num_buckets = num_buckets // 2
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
relative_positions = torch.abs(relative_positions)
else:
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
max_exact = num_buckets // 2
is_small = relative_positions < max_exact
relative_postion_if_large = max_exact + (
torch.log(relative_positions.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_postion_if_large = torch.min(
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length):
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_positions_bucket(
relative_position,
bidirectional=True
)
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(relative_position_bucket)
values = values.permute([2, 0, 1])
return values
def forward(
self,
query,
key: Optional[torch.Tensor],
value: Optional[torch.Tensor],
key_padding_mask: Optional[torch.Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[torch.Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
position_bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, _ = key.size()
if not torch.jit.is_scripting():
assert key_bsz == bsz
assert value is not None
assert src_len, bsz == value.shape[:2]
if self.has_relative_attention_bias and position_bias is None:
position_bias = self.compute_bias(tgt_len, src_len)
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
alpha = 32
q *= 1 / alpha
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.k_head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[torch.Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v, position_bias
if position_bias is not None:
attn_mask_rel_pos = position_bias
if self.gru_rel_pos == 1:
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
_B, _H, _L, __ = query_layer.size()
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
attn_weights = attn_weights + attn_mask_rel_pos
attn_weights_float = F.softmax(
attn_weights, dim=-1
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[torch.Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights, position_bias
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask: Optional[torch.Tensor],
prev_key_padding_mask: Optional[torch.Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[torch.Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
if src_len > prev_key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - prev_key_padding_mask.size(1)),
device=prev_key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), filler.float()], dim=1
)
else:
new_key_padding_mask = prev_key_padding_mask.float()
elif key_padding_mask is not None:
if src_len > key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - key_padding_mask.size(1)),
device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[filler.float(), key_padding_mask.float()], dim=1
)
else:
new_key_padding_mask = key_padding_mask.float()
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]]
) -> Dict[str, Optional[torch.Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[torch.Tensor]] = {}
return empty_result
def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
buffer: Dict[str, Optional[torch.Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights
def init_bert_params(module):
"""
Initialize the weights specific to the BERT Model.
This overrides the default initializations depending on the specified arguments.
1. If normal_init_linear_weights is set then weights of linear
layer will be initialized using the normal distribution and
bais will be set to the specified value.
2. If normal_init_embed_weights is set then weights of embedding
layer will be initialized using the normal distribution.
3. If normal_init_proj_weights is set then weights of
in_project_weight for MultiHeadAttention initialized using
the normal distribution (to be validated).
"""
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
)
if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, MultiheadAttention):
normal_(module.q_proj.weight.data)
normal_(module.k_proj.weight.data)
normal_(module.v_proj.weight.data)
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
class SamePad(nn.Module):
def __init__(self, kernel_size, causal=False):
super().__init__()
if causal:
self.remove = kernel_size - 1
else:
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
self.act = torch.nn.Sigmoid()
def forward(self, x):
return x * self.act(x)
class GLU_Linear(nn.Module):
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
super(GLU_Linear, self).__init__()
self.glu_type = glu_type
self.output_dim = output_dim
if glu_type == "sigmoid":
self.glu_act = torch.nn.Sigmoid()
elif glu_type == "swish":
self.glu_act = Swish()
elif glu_type == "relu":
self.glu_act = torch.nn.ReLU()
elif glu_type == "gelu":
self.glu_act = torch.nn.GELU()
if bias_in_glu:
self.linear = nn.Linear(input_dim, output_dim * 2, True)
else:
self.linear = nn.Linear(input_dim, output_dim * 2, False)
def forward(self, x):
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
x = self.linear(x)
if self.glu_type == "bilinear":
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
else:
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
return x
def gelu_accurate(x):
if not hasattr(gelu_accurate, "_a"):
gelu_accurate._a = math.sqrt(2 / math.pi)
return (
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
)
def gelu(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(x.float()).type_as(x)
def get_activation_fn(activation: str):
"""Returns the activation function corresponding to `activation`"""
if activation == "relu":
return F.relu
elif activation == "gelu":
return gelu
elif activation == "gelu_fast":
warnings.warn(
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
)
return gelu_accurate
elif activation == "gelu_accurate":
return gelu_accurate
elif activation == "tanh":
return torch.tanh
elif activation == "linear":
return lambda x: x
elif activation == "glu":
return lambda x: x
else:
raise RuntimeError("--activation-fn {} not supported".format(activation))
def quant_noise(module, p, block_size):
"""
Wraps modules and applies quantization noise to the weights for
subsequent quantization with Iterative Product Quantization as
described in "Training with Quantization Noise for Extreme Model Compression"
Args:
- module: nn.Module
- p: amount of Quantization Noise
- block_size: size of the blocks for subsequent quantization with iPQ
Remarks:
- Module weights must have the right sizes wrt the block size
- Only Linear, Embedding and Conv2d modules are supported for the moment
- For more detail on how to quantize by blocks with convolutional weights,
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
- We implement the simplest form of noise here as stated in the paper
which consists in randomly dropping blocks
"""
# if no quantization noise, don't register hook
if p <= 0:
return module
# supported modules
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
# test whether module.weight has the right sizes wrt block_size
is_conv = module.weight.ndim == 4
# 2D matrix
if not is_conv:
assert (
module.weight.size(1) % block_size == 0
), "Input features must be a multiple of block sizes"
# 4D matrix
else:
# 1x1 convolutions
if module.kernel_size == (1, 1):
assert (
module.in_channels % block_size == 0
), "Input channels must be a multiple of block sizes"
# regular convolutions
else:
k = module.kernel_size[0] * module.kernel_size[1]
assert k % block_size == 0, "Kernel size must be a multiple of block size"
def _forward_pre_hook(mod, input):
# no noise for evaluation
if mod.training:
if not is_conv:
# gather weight and sizes
weight = mod.weight
in_features = weight.size(1)
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
mask = torch.zeros(
in_features // block_size * out_features, device=weight.device
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
else:
# gather weight and sizes
weight = mod.weight
in_channels = mod.in_channels
out_channels = mod.out_channels
# split weight matrix into blocks and randomly drop selected blocks
if mod.kernel_size == (1, 1):
mask = torch.zeros(
int(in_channels // block_size * out_channels),
device=weight.device,
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
mask = torch.zeros(
weight.size(0), weight.size(1), device=weight.device
)
mask.bernoulli_(p)
mask = (
mask.unsqueeze(2)
.unsqueeze(3)
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
)
# scale weights and apply mask
mask = mask.to(
torch.bool
) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
module.register_forward_pre_hook(_forward_pre_hook)
return module
# Window Level Q Former
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size
)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
)
self.position_embedding_type = getattr(
config, "position_embedding_type", "absolute"
)
self.config = config
def forward(
self,
input_ids=None,
position_ids=None,
query_embeds=None,
past_key_values_length=0,
):
if input_ids is not None:
seq_length = input_ids.size()[1]
else:
seq_length = 0
if position_ids is None:
position_ids = self.position_ids[
:, past_key_values_length : seq_length + past_key_values_length
].clone()
if input_ids is not None:
embeddings = self.word_embeddings(input_ids)
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
if query_embeds is not None:
embeddings = torch.cat((query_embeds, embeddings), dim=1)
else:
embeddings = query_embeds
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config, is_cross_attention):
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
config, "embedding_size"
):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
if is_cross_attention:
self.key = nn.Linear(config.encoder_width, self.all_head_size)
self.value = nn.Linear(config.encoder_width, self.all_head_size)
else:
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(
config, "position_embedding_type", "absolute"
)
if (
self.position_embedding_type == "relative_key"
or self.position_embedding_type == "relative_key_query"
):
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(
2 * config.max_position_embeddings - 1, self.attention_head_size
)
self.save_attention = False
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients
def save_attention_map(self, attention_map):
self.attention_map = attention_map
def get_attention_map(self):
return self.attention_map
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
mixed_query_layer = self.query(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if (
self.position_embedding_type == "relative_key"
or self.position_embedding_type == "relative_key_query"
):
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(
seq_length, dtype=torch.long, device=hidden_states.device
).view(-1, 1)
position_ids_r = torch.arange(
seq_length, dtype=torch.long, device=hidden_states.device
).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(
distance + self.max_position_embeddings - 1
)
positional_embedding = positional_embedding.to(
dtype=query_layer.dtype
) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum(
"bhld,lrd->bhlr", query_layer, positional_embedding
)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum(
"bhld,lrd->bhlr", query_layer, positional_embedding
)
relative_position_scores_key = torch.einsum(
"bhrd,lrd->bhlr", key_layer, positional_embedding
)
attention_scores = (
attention_scores
+ relative_position_scores_query
+ relative_position_scores_key
)
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
if is_cross_attention and self.save_attention:
self.save_attention_map(attention_probs)
attention_probs.register_hook(self.save_attn_gradients)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs_dropped = attention_probs_dropped * head_mask
context_layer = torch.matmul(attention_probs_dropped, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (
(context_layer, attention_probs) if output_attentions else (context_layer,)
)
outputs = outputs + (past_key_value,)
return outputs
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config, is_cross_attention=False):
super().__init__()
self.self = BertSelfAttention(config, is_cross_attention)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads,
self.self.num_attention_heads,
self.self.attention_head_size,
self.pruned_heads,
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = (
self.self.attention_head_size * self.self.num_attention_heads
)
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[
1:
] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config, layer_num):
super().__init__()
self.config = config
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BertAttention(config)
self.layer_num = layer_num
if (
self.config.add_cross_attention
and layer_num % self.config.cross_attention_freq == 0
):
self.crossattention = BertAttention(
config, is_cross_attention=self.config.add_cross_attention
)
self.has_cross_attention = True
else:
self.has_cross_attention = False
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
self.intermediate_query = BertIntermediate(config)
self.output_query = BertOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
query_length=0,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = (
past_key_value[:2] if past_key_value is not None else None
)
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
if query_length > 0:
query_attention_output = attention_output[:, :query_length, :]
if self.has_cross_attention:
assert (
encoder_hidden_states is not None
), "encoder_hidden_states must be given for cross-attention layers"
cross_attention_outputs = self.crossattention(
query_attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
)
query_attention_output = cross_attention_outputs[0]
outputs = (
outputs + cross_attention_outputs[1:-1]
) # add cross attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk_query,
self.chunk_size_feed_forward,
self.seq_len_dim,
query_attention_output,
)
if attention_output.shape[1] > query_length:
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output[:, query_length:, :],
)
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
else:
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)
outputs = (layer_output,) + outputs
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def feed_forward_chunk_query(self, attention_output):
intermediate_output = self.intermediate_query(attention_output)
layer_output = self.output_query(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList(
[BertLayer(config, i) for i in range(config.num_hidden_layers)]
)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
query_length=0,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = (
() if output_attentions and self.config.add_cross_attention else None
)
next_decoder_cache = () if use_cache else None
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(
*inputs, past_key_value, output_attentions, query_length
)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BertConfig
base_model_prefix = "bert"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
class BertModel(BertPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
input to the forward pass.
"""
def __init__(self, config, add_pooling_layer=False):
super().__init__(config)
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def get_extended_attention_mask(
self,
attention_mask: torch.Tensor,
input_shape: Tuple[int],
device: torch.device,
is_decoder: bool,
has_query: bool = False,
) -> torch.Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (:obj:`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (:obj:`Tuple[int]`):
The shape of the input to the model.
device: (:obj:`torch.device`):
The device of the input to the model.
Returns:
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if is_decoder:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = (
seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
<= seq_ids[None, :, None]
)
# add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
if has_query: # UniLM style attention mask
causal_mask = torch.cat(
[
torch.zeros(
(batch_size, prefix_seq_len, seq_length),
device=device,
dtype=causal_mask.dtype,
),
causal_mask,
],
axis=1,
)
causal_mask = torch.cat(
[
torch.ones(
(batch_size, causal_mask.shape[1], prefix_seq_len),
device=device,
dtype=causal_mask.dtype,
),
causal_mask,
],
axis=-1,
)
extended_attention_mask = (
causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
)
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(
dtype=self.dtype
) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
query_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
is_decoder=False,
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is None:
assert (
query_embeds is not None
), "You have to specify query_embeds when input_ids is None"
# past_key_values_length
past_key_values_length = (
past_key_values[0][0].shape[2] - self.config.query_length
if past_key_values is not None
else 0
)
query_length = query_embeds.shape[1] if query_embeds is not None else 0
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
query_embeds=query_embeds,
past_key_values_length=past_key_values_length,
)
input_shape = embedding_output.size()[:-1]
batch_size, seq_length = input_shape
device = embedding_output.device
if attention_mask is None:
attention_mask = torch.ones(
((batch_size, seq_length + past_key_values_length)), device=device
)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if is_decoder:
extended_attention_mask = self.get_extended_attention_mask(
attention_mask,
input_ids.shape,
device,
is_decoder,
has_query=(query_embeds is not None),
)
else:
extended_attention_mask = self.get_extended_attention_mask(
attention_mask, input_shape, device, is_decoder
)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
if type(encoder_hidden_states) == list:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
0
].size()
else:
(
encoder_batch_size,
encoder_sequence_length,
_,
) = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if type(encoder_attention_mask) == list:
encoder_extended_attention_mask = [
self.invert_attention_mask(mask) for mask in encoder_attention_mask
]
elif encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(
encoder_attention_mask
)
else:
encoder_extended_attention_mask = self.invert_attention_mask(
encoder_attention_mask
)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
query_length=query_length,
)
sequence_output = encoder_outputs[0]
pooled_output = (
self.pooler(sequence_output) if self.pooler is not None else None
)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
class BertLMHeadModel(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
query_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
past_key_values=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
return_logits=False,
is_decoder=True,
reduction="mean",
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
Returns:
Example::
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
>>> config = BertConfig.from_pretrained("bert-base-cased")
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if labels is not None:
use_cache = False
if past_key_values is not None:
query_embeds = None
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
query_embeds=query_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_decoder=is_decoder,
)
sequence_output = outputs[0]
if query_embeds is not None:
sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
prediction_scores = self.cls(sequence_output)
if return_logits:
return prediction_scores[:, :-1, :].contiguous()
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = nn.CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
lm_loss = loss_fct(
shifted_prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1),
)
if reduction == "none":
lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(
self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
query_mask = input_ids.new_ones(query_embeds.shape[:-1])
attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"query_embeds": query_embeds,
"attention_mask": attention_mask,
"past_key_values": past,
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
"is_decoder": True,
}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx) for past_state in layer_past
),
)
return reordered_past
class BertForMaskedLM(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
query_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
return_logits=False,
is_decoder=False,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
query_embeds=query_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_decoder=is_decoder,
)
if query_embeds is not None:
sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
prediction_scores = self.cls(sequence_output)
if return_logits:
return prediction_scores
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return (
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
)
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)