import os
import sys
import copy
from pathlib import Path

import torch
import loralib as lora
import transformers.models.whisper.modeling_whisper as whisper
from torch import nn
from transformers.activations import ACT2FN
from transformers import WhisperModel, AutoFeatureExtractor

# Make the repository root importable when this file is run as a script.
sys.path.append(str(Path(os.path.realpath(__file__)).parents[1]))


class WhisperEncoderLayer(nn.Module):
    """Whisper encoder layer with optional LoRA adapters on the feed-forward projections."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = whisper.WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        self.config = config
        # Swap the feed-forward projections for LoRA layers in the upper half of the encoder.
        if layer_idx > config.encoder_layers // 2:
            if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
                self.fc1 = lora.Linear(self.embed_dim, config.encoder_ffn_dim, r=config.lora_rank)
                self.fc2 = lora.Linear(config.encoder_ffn_dim, self.embed_dim, r=config.lora_rank)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
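

# A minimal sketch (not used by the classes in this file) of what the lora.Linear swap above
# provides: loralib keeps the pretrained weight frozen and learns a rank-r update on top of it.
# The sizes below are whisper-tiny's d_model (384) and encoder_ffn_dim (1536); the helper name
# is purely illustrative.
def _lora_linear_demo():
    layer = lora.Linear(384, 1536, r=16)      # drop-in replacement for nn.Linear(384, 1536)
    lora.mark_only_lora_as_trainable(layer)   # freeze everything except lora_A / lora_B
    return [name for name, p in layer.named_parameters() if p.requires_grad]  # ['lora_A', 'lora_B']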


class WhisperWrapper(nn.Module):
    def __init__(
        self,
        pretrain_model="whisper_tiny",
        hidden_dim=256,
        finetune_method="lora",
        lora_rank=16,
        freeze_params=True,
        output_class_num=4,
        use_conv_output=True,
        detailed_class_num=17,
        predict_gender=False,
    ):
        super(WhisperWrapper, self).__init__()
        # 1. Load the pretrained backbone and its feature extractor
        self.pretrain_model = pretrain_model
        self.finetune_method = finetune_method
        self.freeze_params = freeze_params
        self.use_conv_output = use_conv_output
        self.lora_rank = lora_rank
        self.predict_gender = predict_gender
        # The 80-bin log-mel extractor is shared by whisper tiny/base/small/medium;
        # chunk_length=15 keeps inputs at 15 s, matching max_source_positions=750 below.
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny", chunk_length=15)
        if self.pretrain_model == "whisper_tiny":
            self.backbone_model = WhisperModel.from_pretrained(
                "openai/whisper-tiny",
                output_hidden_states=True,
                ignore_mismatched_sizes=True,
                max_source_positions=750,
            )
        elif self.pretrain_model == "whisper_base":
            self.backbone_model = WhisperModel.from_pretrained(
                "openai/whisper-base",
                output_hidden_states=True,
                ignore_mismatched_sizes=True,
                max_source_positions=750,
            )
        elif self.pretrain_model == "whisper_small":
            self.backbone_model = WhisperModel.from_pretrained(
                "openai/whisper-small",
                output_hidden_states=True,
                ignore_mismatched_sizes=True,
                max_source_positions=750,
            )
        elif self.pretrain_model == "whisper_medium":
            self.backbone_model = WhisperModel.from_pretrained(
                "openai/whisper-medium",
                output_hidden_states=True,
                ignore_mismatched_sizes=True,
                max_source_positions=750,
            )
        elif self.pretrain_model == "whisper_large":
            # whisper-large-v3 uses a different (128-bin) mel front end, so reload the matching extractor.
            self.feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3", chunk_length=15)
            self.backbone_model = WhisperModel.from_pretrained(
                "openai/whisper-large-v3",
                output_hidden_states=True,
                ignore_mismatched_sizes=True,
                max_source_positions=750,
            )
        self.embed_positions = copy.deepcopy(self.backbone_model.encoder.embed_positions.weight)
        self.embed_positions.requires_grad = False
        state_dict = self.backbone_model.state_dict()
        # 2. Read the model config
        self.model_config = self.backbone_model.config
        self.model_config.finetune_method = self.finetune_method
        self.model_config.lora_rank = self.lora_rank
        if self.finetune_method == "lora":
            # 3. Rebuild the encoder layers with LoRA-enabled feed-forward projections
            self.backbone_model.encoder.layers = nn.ModuleList(
                [WhisperEncoderLayer(self.model_config, layer_idx) for layer_idx in range(self.model_config.encoder_layers)]
            )
        # 4. Load the pretrained weights back; the freshly added LoRA matrices stay in msg.missing_keys
        msg = self.backbone_model.load_state_dict(state_dict, strict=False)
        # 5. Freeze the backbone weights
        if self.freeze_params and self.finetune_method != "lora":
            for _, p in self.backbone_model.named_parameters():
                p.requires_grad = False
        elif self.freeze_params and self.finetune_method == "lora":
            # Only the LoRA parameters (the keys missing from the pretrained state dict) stay trainable
            for name, p in self.backbone_model.named_parameters():
                p.requires_grad = name in msg.missing_keys
        else:
            for name, p in self.backbone_model.named_parameters():
                if "decoder" not in name and "conv1" not in name and "conv2" not in name and "embed_positions" not in name:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
        # 6. Downstream models
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
        )
        if self.use_conv_output:
            num_layers = self.model_config.num_hidden_layers + 1  # transformer layers + input embeddings
            self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        else:
            num_layers = self.model_config.num_hidden_layers
            self.weights = nn.Parameter(torch.zeros(num_layers))
        self.emotion_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_class_num),
        )
        self.detailed_out_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, detailed_class_num),
        )
        self.arousal_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        self.valence_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        self.dominance_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        if self.predict_gender:
            self.gender_layer = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 2)
            )

    def forward(self, x, length=None):
        # 1. Feature extraction: raw 16 kHz waveforms -> log-mel input features
        max_audio_len = 15 * 16000
        if length is not None:
            # The feature extractor expects a list of 1-D arrays
            new_x = list()
            for idx in range(len(length)):
                new_x.append(x[idx].detach().cpu().numpy())
            features = self.feature_extractor(
                new_x,
                return_tensors="pt",
                sampling_rate=16000,
                max_length=max_audio_len
            )
            features = features.input_features.cuda()
        else:
            features = self.feature_extractor(
                x[0].detach().cpu(),
                return_tensors="pt",
                sampling_rate=16000,
                max_length=max_audio_len
            )
            features = features.input_features.cuda()
        # 2. Get the number of valid encoder frames per utterance
        if length is not None:
            length = self._get_feat_extract_output_lengths(length.detach().cpu())
        else:
            length = torch.tensor([len(x[0])])
            length = self._get_feat_extract_output_lengths(length)
        # Replace the positional embeddings with the (frozen) first 750 pretrained positions
        self.backbone_model.encoder.embed_positions = self.backbone_model.encoder.embed_positions.from_pretrained(self.embed_positions[:750])
        # 3. Transformer encoding; keep only the last encoder hidden state
        features = self.backbone_model.encoder(
            features, output_hidden_states=True
        ).hidden_states
        features = torch.stack(features, dim=0)[-1]
        # 4. Point-wise 1-D conv projection over the time dimension (B x T x D)
        features = features.transpose(1, 2)
        features = self.model_seq(features)
        features = features.transpose(1, 2)
        # 5. Mean pooling over time, skipping padded frames when lengths are known
        if length is not None:
            mean = list()
            for snt_id in range(features.shape[0]):
                # Average only over the valid (non-padded) time steps
                actual_size = length[snt_id]
                mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
            features = torch.stack(mean)
        else:
            features = torch.mean(features, dim=1)
        # 6. Output predictions from the pooled B x D features
        arousal = self.arousal_layer(features)
        valence = self.valence_layer(features)
        dominance = self.dominance_layer(features)
        if self.predict_gender:
            gender_outputs = self.gender_layer(features)
            return arousal, valence, dominance, gender_outputs
        return arousal, valence, dominance

    # Adapted from huggingface
    def _get_feat_extract_output_lengths(self, input_lengths):
        """
        Computes the number of encoder frames produced for raw 16 kHz sample lengths:
        the log-mel front end uses a hop of 160 samples, followed by a stride-2 convolution.
        """
        input_lengths = input_lengths // 160
        input_lengths = (input_lengths - 1) // 2 + 1
        return input_lengths
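

# A minimal usage sketch, not part of the model code. Assumptions: a CUDA device is available
# (forward() calls .cuda() on the extracted features) and the checkpoints above can be downloaded.
# The dummy waveforms and lengths below are illustrative only.
if __name__ == "__main__":
    model = WhisperWrapper(pretrain_model="whisper_tiny", hidden_dim=256).cuda()
    model.eval()
    # Batch of two raw 16 kHz waveforms padded to 15 s (the configured chunk length)
    audio = torch.zeros(2, 15 * 16000)
    audio[0, :10 * 16000] = torch.randn(10 * 16000)
    audio[1] = torch.randn(15 * 16000)
    lengths = torch.tensor([10 * 16000, 15 * 16000])
    with torch.no_grad():
        arousal, valence, dominance = model(audio, length=lengths)
    # Each head returns a (batch, 1) tensor squashed to [0, 1] by the final Sigmoid
    print(arousal.shape, valence.shape, dominance.shape)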