Epsilon617 commited on
Commit
92cd759
·
1 Parent(s): c2c7513

add genre prediction head

Browse files
Prediction_Head/MTGGenre_head.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ class MLPProberBase(nn.Module):
6
+ def __init__(self, d=768, num_outputs=87):
7
+ super().__init__()
8
+ self.hidden_layer_sizes = [512, ] # eval(self.cfg.hidden_layer_sizes)
9
+ self.num_layers = len(self.hidden_layer_sizes)
10
+ for i, ld in enumerate(self.hidden_layer_sizes):
11
+ setattr(self, f"hidden_{i}", nn.Linear(d, ld))
12
+ d = ld
13
+ self.output = nn.Linear(d, num_outputs)
14
+
15
+ def forward(self, x):
16
+ for i in range(self.num_layers):
17
+ x = getattr(self, f"hidden_{i}")(x)
18
+ # x = self.dropout(x)
19
+ x = F.relu(x)
20
+ output = self.output(x)
21
+ return output
Prediction_Head/MTGGenre_id2class.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"0": "genre---rock", "1": "genre---pop", "2": "genre---classical", "3": "genre---popfolk", "4": "genre---disco", "5": "genre---funk", "6": "genre---rnb", "7": "genre---ambient", "8": "genre---chillout", "9": "genre---downtempo", "10": "genre---easylistening", "11": "genre---electronic", "12": "genre---lounge", "13": "genre---triphop", "14": "genre---breakbeat", "15": "genre---techno", "16": "genre---newage", "17": "genre---jazz", "18": "genre---metal", "19": "genre---industrial", "20": "genre---instrumentalrock", "21": "genre---minimal", "22": "genre---alternative", "23": "genre---experimental", "24": "genre---drumnbass", "25": "genre---soul", "26": "genre---fusion", "27": "genre---soundtrack", "28": "genre---electropop", "29": "genre---world", "30": "genre---ethno", "31": "genre---trance", "32": "genre---orchestral", "33": "genre---grunge", "34": "genre---chanson", "35": "genre---worldfusion", "36": "genre---hiphop", "37": "genre---groove", "38": "genre---instrumentalpop", "39": "genre---blues", "40": "genre---reggae", "41": "genre---dance", "42": "genre---club", "43": "genre---punkrock", "44": "genre---folk", "45": "genre---synthpop", "46": "genre---poprock", "47": "genre---choir", "48": "genre---symphonic", "49": "genre---indie", "50": "genre---progressive", "51": "genre---acidjazz", "52": "genre---contemporary", "53": "genre---newwave", "54": "genre---dub", "55": "genre---rocknroll", "56": "genre---hard", "57": "genre---hardrock", "58": "genre---house", "59": "genre---atmospheric", "60": "genre---psychedelic", "61": "genre---improvisation", "62": "genre---country", "63": "genre---electronica", "64": "genre---rap", "65": "genre---60s", "66": "genre---70s", "67": "genre---darkambient", "68": "genre---idm", "69": "genre---latin", "70": "genre---postrock", "71": "genre---bossanova", "72": "genre---singersongwriter", "73": "genre---darkwave", "74": "genre---swing", "75": "genre---medieval", "76": "genre---celtic", "77": "genre---eurodance", "78": "genre---classicrock", "79": "genre---dubstep", "80": "genre---bluesrock", "81": "genre---edm", "82": "genre---deephouse", "83": "genre---jazzfusion", "84": "genre---alternativerock", "85": "genre---80s", "86": "genre---90s"}
Prediction_Head/__pycache__/MTGGenre_head.cpython-310.pyc ADDED
Binary file (1.08 kB). View file
 
Prediction_Head/best_MTGGenre.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83b7dcffde10a0dc7ba74341ea56dabec5c5de7cad6a0483708c80f1d893514a
3
+ size 1759067
__pycache__/app.cpython-310.pyc CHANGED
Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
 
app.py CHANGED
@@ -8,9 +8,12 @@ import torchaudio
8
  import torchaudio.transforms as T
9
  import logging
10
 
 
 
11
  import importlib
12
  modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")
13
 
 
14
  # input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py
15
 
16
 
@@ -34,7 +37,7 @@ live_inputs = [
34
  ]
35
  # outputs = [gr.components.Textbox()]
36
  # outputs = [gr.components.Textbox(), transcription_df]
37
- title = "Output the tags of a (music) audio"
38
  description = "An example of using MERT-95M-public to conduct music tagging."
39
  article = ""
40
  audio_examples = [
@@ -48,9 +51,17 @@ audio_examples = [
48
  model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
49
  processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")
50
 
 
 
 
 
 
 
 
51
 
52
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
53
  model.to(device)
 
54
 
55
  def convert_audio(inputs, microphone):
56
  if (microphone is not None):
@@ -75,10 +86,17 @@ def convert_audio(inputs, microphone):
75
  # take a look at the output shape, there are 13 layers of representation
76
  # each layer performs differently in different downstream tasks, you should choose empirically
77
  all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
78
- # print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
 
 
 
 
 
 
79
  # logger.warning(all_layer_hidden_states.shape)
80
 
81
- return f"device {device}\n sample reprensentation: {str(all_layer_hidden_states[12, 0, :10])}"
 
82
 
83
  def live_convert_audio(microphone):
84
  if (microphone is not None):
@@ -103,10 +121,17 @@ def live_convert_audio(microphone):
103
  # take a look at the output shape, there are 13 layers of representation
104
  # each layer performs differently in different downstream tasks, you should choose empirically
105
  all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
106
- # print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
 
 
 
 
 
 
107
  # logger.warning(all_layer_hidden_states.shape)
108
 
109
- return f"device {device}, sample reprensentation: {str(all_layer_hidden_states[12, 0, :10])}"
 
110
 
111
 
112
  audio_chunked = gr.Interface(
 
8
  import torchaudio.transforms as T
9
  import logging
10
 
11
+ import json
12
+
13
  import importlib
14
  modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")
15
 
16
+ from Prediction_Head.MTGGenre_head import MLPProberBase
17
  # input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py
18
 
19
 
 
37
  ]
38
  # outputs = [gr.components.Textbox()]
39
  # outputs = [gr.components.Textbox(), transcription_df]
40
+ title = "Predict the top 5 possible genres of Music"
41
  description = "An example of using MERT-95M-public to conduct music tagging."
42
  article = ""
43
  audio_examples = [
 
51
  model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
52
  processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")
53
 
54
+ MERT_LAYER_IDX = 7
55
+ MTGGenre_classifier = MLPProberBase()
56
+ MTGGenre_classifier.load_state_dict(torch.load('Prediction_Head/best_MTGGenre.ckpt')['state_dict'])
57
+
58
+ with open('Prediction_Head/MTGGenre_id2class.json', 'r') as f:
59
+ id2cls=json.load(f)
60
+
61
 
62
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
63
  model.to(device)
64
+ MTGGenre_classifier.to(device)
65
 
66
  def convert_audio(inputs, microphone):
67
  if (microphone is not None):
 
86
  # take a look at the output shape, there are 13 layers of representation
87
  # each layer performs differently in different downstream tasks, you should choose empirically
88
  all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
89
+ print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
90
+
91
+ logits = MTGGenre_classifier(torch.mean(all_layer_hidden_states[MERT_LAYER_IDX], dim=0)) # [1, 87]
92
+ print(logits.shape)
93
+ sorted_idx = torch.argsort(logits, dim = -1, descending=True)
94
+
95
+ output_texts = "\n".join([id2cls[str(idx.item())].replace('genre---', '') for idx in sorted_idx[:5]])
96
  # logger.warning(all_layer_hidden_states.shape)
97
 
98
+ # return f"device {device}, sample reprensentation: {str(all_layer_hidden_states[12, 0, :10])}"
99
+ return f"device: {device}\n" + output_texts
100
 
101
  def live_convert_audio(microphone):
102
  if (microphone is not None):
 
121
  # take a look at the output shape, there are 13 layers of representation
122
  # each layer performs differently in different downstream tasks, you should choose empirically
123
  all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
124
+ print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
125
+
126
+ logits = MTGGenre_classifier(torch.mean(all_layer_hidden_states[MERT_LAYER_IDX], dim=0)) # [1, 87]
127
+ print(logits.shape)
128
+ sorted_idx = torch.argsort(logits, dim = -1, descending=True)
129
+
130
+ output_texts = "\n".join([id2cls[str(idx.item())].replace('genre---', '') for idx in sorted_idx[:5]])
131
  # logger.warning(all_layer_hidden_states.shape)
132
 
133
+ # return f"device {device}, sample reprensentation: {str(all_layer_hidden_states[12, 0, :10])}"
134
+ return f"device: {device}\n" + output_texts
135
 
136
 
137
  audio_chunked = gr.Interface(