File size: 3,527 Bytes
258fd02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from musiclm_pytorch import MuLaNEmbedder
import hydra
import librosa
import torch
import yaml
from prodict import Prodict

from musiclm_pytorch import AudioSpectrogramTransformerPretrained, TextTransformerPretrained, MuLaN
from omegaconf import DictConfig
import os

def get_pretrained_config(root, name):
    """Resolve a model name to a local snapshot directory, if a root is given.

    When ``root`` is ``None`` the ``name`` is returned unchanged, so the caller
    can hand it directly to the model loader. Otherwise the function looks in
    ``<root>/<name>/snapshots`` (presumably a Hugging Face-style cache layout;
    confirm) and returns the path of the single sub-directory found there.

    Args:
        root: Directory holding cached models, or ``None``.
        name: Model name; a sub-directory of ``root`` when ``root`` is given.

    Returns:
        ``name`` if ``root`` is ``None``, otherwise the full path of the only
        sub-directory of ``<root>/<name>/snapshots``.

    Raises:
        FileNotFoundError: If ``<root>/<name>/snapshots`` does not exist
            (propagated from ``os.listdir``).
        ValueError: If that directory does not contain exactly one
            sub-directory.
    """
    if root is None:
        return name
    snapshots_dir = os.path.join(root, name, 'snapshots')
    # Only sub-directories count as snapshots; stray files (e.g. .DS_Store)
    # must not break the uniqueness check.
    snapshots = [
        entry for entry in os.listdir(snapshots_dir)
        if os.path.isdir(os.path.join(snapshots_dir, entry))
    ]
    # Explicit check rather than `assert`, which is removed under `python -O`
    # and would let an arbitrary snapshot be picked silently.
    if len(snapshots) != 1:
        raise ValueError(
            f"expected exactly one snapshot directory in {snapshots_dir!r}, "
            f"found {len(snapshots)}: {sorted(snapshots)}"
        )
    return os.path.join(snapshots_dir, snapshots[0])
        
def create_MuLaN_from_config(config: DictConfig):
    """Build a ``MuLaN`` model from a configuration object.

    Everything is read through attribute access under ``config.model``:

    * ``pretraind_root`` - optional cache root for pretrained weights;
    * ``audio_model`` - ``name``, ``model_dim`` and ``use_layer_idx``;
    * ``text_model`` - ``name``;
    * ``audio_transformer`` / ``text_transformer`` / ``mulan`` - mappings
      unpacked as extra keyword arguments to the respective constructors.

    Model names go through ``get_pretrained_config``; when the root is
    ``None`` the configured names are used as-is.

    Returns:
        The assembled ``MuLaN`` instance.
    """
    model_cfg = config.model
    root_dir = model_cfg.pretraind_root

    # Audio tower.
    audio_cfg = model_cfg.audio_model
    audio_tower = AudioSpectrogramTransformerPretrained(
        model_name=get_pretrained_config(root_dir, audio_cfg.name),
        model_dim=audio_cfg.model_dim,
        use_layer_idx=audio_cfg.use_layer_idx,
        **model_cfg.audio_transformer,
    )

    # Text tower.
    text_tower = TextTransformerPretrained(
        model_name=get_pretrained_config(root_dir, model_cfg.text_model.name),
        **model_cfg.text_transformer,
    )

    # Joint model tying both towers together.
    return MuLaN(
        audio_transformer=audio_tower,
        text_transformer=text_tower,
        **model_cfg.mulan,
    )


def create_CLAP_model(model_kwargs=None, ckpt_path=None):
    """Build a frozen LAION-CLAP model wrapped to return a contrastive loss.

    Args:
        model_kwargs: Extra keyword arguments for ``laion_clap.CLAP_Module``
            (``enable_fusion`` is always passed as ``False``). ``None`` means
            no extra arguments.
        ckpt_path: Checkpoint given to ``CLAP_Module.load_ckpt``. When
            ``None``, ``load_ckpt()`` is called without arguments, i.e. the
            library's default checkpoint is used.

    Returns:
        A ``CLAP_Model`` module whose ``forward(wavs, raw_texts)`` returns the
        ``SoftmaxContrastiveLearning`` loss between the CLAP audio and text
        embeddings. The wrapper assumes 24 kHz input audio.
    """
    from musiclm_pytorch import SoftmaxContrastiveLearning
    import laion_clap

    from torch import nn
    import torch
    from torchaudio.functional import resample

    # Avoid a shared mutable default argument; None means "no extra kwargs".
    if model_kwargs is None:
        model_kwargs = {}

    # Quantization round-trip: clip to [-1, 1], scale into the int16 range
    # (stored as int32 by `.int()`, truncating toward zero), then scale back.
    def int16_to_float32(x):
        return (x / 32767.0).float()

    def float32_to_int16(x):
        x = torch.clip(x, min=-1., max=1.)
        return (x * 32767.).int()

    model = laion_clap.CLAP_Module(enable_fusion=False, **model_kwargs)
    if ckpt_path is not None:
        model.load_ckpt(ckpt_path)
    else:
        model.load_ckpt()

    class CLAP_Model(nn.Module):
        """Frozen CLAP encoders followed by a contrastive loss head."""

        def __init__(self, model, sr=24000, decoupled_contrastive_learning=True):
            super().__init__()
            self.model = model
            self.model.eval()
            # Sample rate of waveforms given to forward(); resampled to 48 kHz there.
            self.orig_sr = sr
            self.contrast = SoftmaxContrastiveLearning(
                decoupled_contrastive_learning=decoupled_contrastive_learning
            )

        def forward(self, wavs, raw_texts):
            """Return the contrastive loss for a batch of audio/text pairs.

            Args:
                wavs: Waveform tensor sampled at ``self.orig_sr`` (shape
                    presumably (batch, samples); TODO confirm against
                    ``get_audio_embedding_from_data``).
                raw_texts: Texts passed to ``get_text_embedding``.
            """
            # CLAP encoders are frozen: no gradients flow through them.
            with torch.no_grad():
                wavs = resample(wavs, self.orig_sr, 48000)
                wavs = int16_to_float32(float32_to_int16(wavs))
                audio_latents = self.model.get_audio_embedding_from_data(
                    x=wavs, use_tensor=True
                ).float()
                # Use the registered submodule rather than the enclosing
                # closure's `model`, so copies of this wrapper (deepcopy, etc.)
                # use their own CLAP instance for both modalities.
                text_latents = self.model.get_text_embedding(raw_texts, use_tensor=True)
            # Computed outside no_grad so parameters of the contrastive head
            # (if it has any) can still receive gradients.
            cl_loss = self.contrast(audio_latents, text_latents)
            return cl_loss

    clap = CLAP_Model(model)
    return clap

def get_mulan(config):
    """Load a MuLaN embedder described by a YAML config file, in eval mode.

    Args:
        config: Path to a YAML file. Its ``checkpoint_path`` entry is passed
            to ``MuLaNEmbedder``; the whole parsed config (wrapped in a
            ``Prodict`` for attribute access) goes to
            ``create_MuLaN_from_config``.

    Returns:
        The ``MuLaNEmbedder`` after ``.eval()`` has been called on it.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which can mis-read non-ASCII content (e.g. on Windows).
    with open(config, "r", encoding="utf-8") as stream:
        raw_config = yaml.safe_load(stream)
    mulan_config = Prodict.from_dict(raw_config)

    ckpt_path = mulan_config.checkpoint_path
    mulan = create_MuLaN_from_config(mulan_config)
    mulan_embedder = MuLaNEmbedder(mulan, checkpoint_path=ckpt_path)
    mulan_embedder.eval()

    return mulan_embedder