import gradio as gr

from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio
import torchaudio.transforms as T
import logging

import json
import os
import re

import pandas as pd
import librosa

import importlib 
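# the MERT modeling code ships inside the local model snapshot, so import it dynamically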
modeling_MERT = importlib.import_module("MERT-v1-95M.modeling_MERT")

from Prediction_Head.MTGGenre_head import MLPProberBase 


logger = logging.getLogger("MERT-v1-95M-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)

title = "One Model for All Music Understanding Tasks"
description = "An example of using the [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) model as backbone to conduct multiple music understanding tasks with the universal representation. \n Due the hardware limitation of the machine hosting this demo (2 CPU and 16GB RAM) only the first 4 seconds of audio are used!"
with open('./README.md', 'r') as f:
    # skip the header
    header_count = 0
    for line in f:
        if '---' in line:
            header_count += 1
        if header_count >= 2:
            break
    # read the rest of the content
    article = f.read()
    
df_init = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5'])
transcription_df = gr.DataFrame(value=df_init, label="Output Dataframe", row_count=(
    0, "dynamic"), wrap=True)
outputs = transcription_df

# model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
# processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public",trust_remote_code=True)
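# load the MERT-v1-95M backbone and its feature extractor from the local snapshot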
model = modeling_MERT.MERTModel.from_pretrained("./MERT-v1-95M")
processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v1-95M")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

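# best-performing hidden layer per downstream task, chosen empirically during
# probing; 'all' means the head consumes features from every layer at once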
MERT_BEST_LAYER_IDX = {
    'EMO': 5,
    'GS': 8,
    'GTZAN': 7,
    'MTGGenre': 7,
    'MTGInstrument': 'all',
    'MTGMood': 6,
    'MTGTop50': 6,
    'MTT': 'all',
    'NSynthI': 6,
    'NSynthP': 1,
    'VocalSetS': 2,
    'VocalSetT': 9,
} 

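# per-task probing heads and their id -> class-name mappings, filled in below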
CLASSIFIERS = {}

ID2CLASS = {}

TASKS = ['GS', 'MTGInstrument', 'MTGGenre', 'MTGTop50', 'MTGMood', 'NSynthI', 'NSynthP', 'VocalSetS', 'VocalSetT','EMO',]
Regression_TASKS = ['EMO']
head_dir = './Prediction_Head/best-layer-MERT-v1-95M'
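# load each task's pretrained probing head and label mapping from the head directory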
for task in TASKS:
    print('loading', task)
    with open(os.path.join(head_dir,f'{task}.id2class.json'), 'r') as f:
        ID2CLASS[task]=json.load(f)
    num_class = len(ID2CLASS[task].keys())
    CLASSIFIERS[task] = MLPProberBase(d=768, layer=MERT_BEST_LAYER_IDX[task], num_outputs=num_class)
    # map_location keeps checkpoint loading working on CPU-only hosts
    CLASSIFIERS[task].load_state_dict(torch.load(f'{head_dir}/{task}.ckpt', map_location='cpu')['state_dict'])
    CLASSIFIERS[task].eval()  # the probing heads are inference-only here
    CLASSIFIERS[task].to(device)

model.to(device)

def model_inference(inputs):
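    """Embed the audio with MERT once, then score it with every task head.

    Returns a DataFrame with the top-5 predictions per task.
    """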
    waveform, sample_rate = torchaudio.load(inputs)

    resample_rate = processor.sampling_rate

    # resample if the file's rate differs from what the feature extractor expects
    if resample_rate != sample_rate:
        # print(f'setting rate from {sample_rate} to {resample_rate}')
        resampler = T.Resample(sample_rate, resample_rate)
        waveform = resampler(waveform)
    
    # keep only the first channel and truncate to 4 seconds (see demo description)
    waveform = waveform[0][0:4*resample_rate]
    
    model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
    model_inputs.to(device)
    with torch.no_grad():
        model_outputs = model(**model_inputs, output_hidden_states=True)

    # hidden_states holds 13 tensors (embedding output + 12 transformer layers);
    # each layer performs differently on different downstream tasks, so each
    # probing head picks its layer empirically (see MERT_BEST_LAYER_IDX)
    all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()[1:,:,:].unsqueeze(0)
    print(all_layer_hidden_states.shape) # [1, 12 layers, time steps, 768 feature_dim]
    all_layer_hidden_states = all_layer_hidden_states.mean(dim=2) # average over time -> [1, 12, 768]

    task_output_texts = ""
    df_objects = []

    for task in TASKS:
        num_class = len(ID2CLASS[task].keys())
        if MERT_BEST_LAYER_IDX[task] == 'all':
            logits = CLASSIFIERS[task](all_layer_hidden_states) # [1, 87]
        else:
            logits = CLASSIFIERS[task](all_layer_hidden_states[:, MERT_BEST_LAYER_IDX[task]])
        # print(f'task {task} logits:', logits.shape, 'num class:', num_class)
        
        sorted_idx = torch.argsort(logits, dim=-1, descending=True)[0]  # batch size is 1
        sorted_prob, _ = torch.sort(nn.functional.softmax(logits[0], dim=-1), dim=-1, descending=True)
        # print(sorted_prob)
        # print(sorted_prob.shape)
        
        top_n_show = min(5, num_class)
        # task_output_texts = task_output_texts + f"TASK {task} output:\n" + "\n".join([str(ID2CLASS[task][str(sorted_idx[idx].item())])+f', probability: {sorted_prob[idx].item():.2%}' for idx in range(top_n_show)]) + '\n'
        # task_output_texts = task_output_texts + '----------------------\n'

        row_elements = [task]
        for idx in range(top_n_show):
            # print(ID2CLASS[task])
            # print('id', str(sorted_idx[idx].item()))
            output_class_name = str(ID2CLASS[task][str(sorted_idx[idx].item())])
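            # MTG labels carry prefixes like "genre---rock" or "mood/theme---happy"; strip them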
            output_class_name = re.sub(r'^\w+---', '', output_class_name)
            output_class_name = re.sub(r'^\w+\/\w+---', '', output_class_name)
            # print('output name', output_class_name)
            output_prob = f' {sorted_prob[idx].item():.2%}'
            row_elements.append(output_class_name+output_prob)
        # pad the row with blanks when a task has fewer than 5 classes
        for _ in range(5+1 - len(row_elements)):
            row_elements.append(' ')
        df_objects.append(row_elements)
    df = pd.DataFrame(df_objects, columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5'])
    return df

def convert_audio(inputs):
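    """Gradio callback: run the full inference pipeline on the uploaded file."""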
    #audio_data, sample_rate = librosa.load(inputs, sr=None)
    return model_inference(inputs)    

def build_audio_flow(title, description, article):
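    """Build the Gradio Interface wiring the audio input to the prediction table."""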
    audio_file = gr.File(label="Select Audio File (*.wav)")

    demo =  gr.Interface(
        fn=convert_audio,
        inputs=audio_file,
        outputs=outputs,
        allow_flagging="never",
        title=title,
        description=description,
        article=article,
    )

    return demo

demo = build_audio_flow(title, description, article)
# demo.queue(concurrency_count=1, max_size=5)
demo.launch()