import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

from transformers import RobertaTokenizer  # assumed: the RoBERTa tokenizer behind self.tokenizer()

from .clap_modules.open_clip import create_model
from .clap_modules.training.data import get_audio_features
from ..common.get_model import register

@register('clap_audio')
class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
    """Uses the CLAP audio encoder"""
    def __init__(
        self,
        pretrained_path="",
        key="waveform",
        sampling_rate=16000,
        embed_mode="audio",
        unconditional_prob=0.1,
        random_mute=False,
        max_random_mute_portion=0.5,
        training_mode=True,
        joint_embed_shape=768,
        embed_shape=512,
        num_layers=12,
        depths=[2, 2, 6, 2],
        amodel="HTSAT-large",
    ):
        super().__init__()
        self.key = key
        self.amodel = amodel  # or 'PANN-14'
        self.tmodel = "roberta"  # the best text encoder in our training
        self.enable_fusion = False  # False if you do not want to use the fusion model
        self.fusion_type = "aff_2d"
        self.pretrained = pretrained_path
        self.embed_mode = embed_mode
        self.embed_mode_orig = embed_mode
        self.sampling_rate = sampling_rate
        self.unconditional_prob = unconditional_prob
        self.random_mute = random_mute
        self.joint_embed_shape = joint_embed_shape
        self.max_random_mute_portion = max_random_mute_portion
        self.training_mode = training_mode
        self.model, self.model_cfg = create_model(
            self.amodel,
            self.tmodel,
            self.pretrained,
            precision="fp32",
            device="cpu",
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
            joint_embed_shape=self.joint_embed_shape,
        )
        # RoBERTa tokenizer backing self.tokenizer(); assumed to match
        # self.tmodel ("roberta"), which get_unconditional_condition() relies on.
        self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
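
    def tokenizer(self, text):
        # Sketch of the assumed tokenizer helper: get_unconditional_condition()
        # calls self.tokenizer(["", ""]), so this wraps the RoBERTa tokenizer
        # and returns the tensor dict expected by model.get_text_embedding().
        # max_length=512 matches RoBERTa's context window; the squeeze(0)
        # drops a singleton batch dim for single-string inputs.
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        return {k: v.squeeze(0) for k, v in result.items()}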

    def get_dtype(self):
        return next(self.model.parameters()).dtype

    def get_unconditional_condition(self, batchsize):
        # Embed the empty string once and tile it across the batch: [bs, 1, 768]
        self.unconditional_token = self.model.get_text_embedding(
            self.tokenizer(["", ""])
        )[0:1]
        return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)

    def batch_to_list(self, batch):
        # Split a batched tensor into a list of per-item tensors
        return [batch[i] for i in range(batch.size(0))]

    def make_decision(self, probability):
        # Bernoulli draw: True with the given probability
        return float(torch.rand(1)) < probability

    def random_uniform(self, start, end):
        val = torch.rand(1).item()
        return start + (end - start) * val

    def _random_mute(self, waveform):
        # waveform: [bs, t-steps]
        t_steps = waveform.size(-1)
        for i in range(waveform.size(0)):
            mute_size = int(
                self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
            )
            mute_start = int(self.random_uniform(0, t_steps - mute_size))
            waveform[i, mute_start : mute_start + mute_size] = 0
        return waveform

    def cos_similarity(self, waveform, text):
        # waveform: [bs, t_steps]
        with torch.no_grad():
            self.embed_mode = "audio"
            audio_emb = self(waveform.cuda())
            self.embed_mode = "text"
            text_emb = self(text)
            similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
            return similarity.squeeze()

    def forward(self, batch, key=None):
        # the 'fusion' truncation mode can be changed to 'rand_trunc' when fusion is disabled
        if self.embed_mode == "audio":
            audio_dict_list = []
            assert (
                self.sampling_rate == 16000
            ), "We only support a 16000 Hz sampling rate"
            # batch: [bs, 1, t-samples]
            batch = torchaudio.functional.resample(
                batch, orig_freq=self.sampling_rate, new_freq=48000
            )
            for waveform in self.batch_to_list(batch):
                audio_dict = {}
                audio_dict = get_audio_features(
                    audio_dict,
                    waveform.squeeze(),
                    480000,
                    data_truncating="fusion",
                    data_filling="repeatpad",
                    audio_cfg=self.model_cfg["audio_cfg"],
                    dtype=self.get_dtype(),
                )
                audio_dict_list.append(audio_dict)
            # [bs, 768]
            embed = self.model.get_audio_embedding(audio_dict_list)
        else:
            # Text branch (assumed, mirroring upstream AudioLDM/CLAP code):
            # cos_similarity() flips embed_mode to "text" and calls forward()
            # with a list of strings, so embed them through the same text
            # tower used by get_unconditional_condition().
            embed = self.model.get_text_embedding(self.tokenizer(batch))
        embed = embed.unsqueeze(1)
        # [bs, 1, 768]
        return embed
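

# Minimal usage sketch, assuming a locally available CLAP checkpoint; the
# path below is hypothetical and the waveform is random noise, so the point
# is the output shape, not the values.
if __name__ == "__main__":
    clap = CLAPAudioEmbeddingClassifierFreev2(
        pretrained_path="ckpt/clap-htsat-large.pt",  # hypothetical checkpoint path
        sampling_rate=16000,
        embed_mode="audio",
    )
    batch = torch.randn(2, 1, 16000 * 10)  # [bs, 1, t-samples], 10 s at 16 kHz
    with torch.no_grad():
        embed = clap(batch)
    print(embed.shape)  # expected: torch.Size([2, 1, 768])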