yizhilll committed on
Commit
806f23c
1 Parent(s): e210c1a

Update README.md

Files changed (1)
  1. README.md +53 -0
README.md CHANGED
@@ -14,6 +14,59 @@ The training settings and model usage of MERT-v0-public can be referred to the [
 
 Details are reported at the short article *Large-Scale Pretrained Model for Self-Supervised Music Audio Representation Learning*.
 
+ # Demo code
+
+ ```python
+ from transformers import Wav2Vec2FeatureExtractor
+ from transformers import AutoModel
+ import torch
+ from torch import nn
+ import torchaudio.transforms as T
+ from datasets import load_dataset
+
+
+ # loading our model weights
+ model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
+ # loading the corresponding preprocessor config
+ processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
+
+ # load demo audio and set processor
+ dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+ dataset = dataset.sort("id")
+ sampling_rate = dataset.features["audio"].sampling_rate
+
+ resample_rate = processor.sampling_rate
+ # make sure the sample rate is aligned
+ if resample_rate != sampling_rate:
+     print(f'setting rate from {sampling_rate} to {resample_rate}')
+     resampler = T.Resample(sampling_rate, resample_rate)
+ else:
+     resampler = None
+
+ # audio file is decoded on the fly
+ if resampler is None:
+     input_audio = dataset[0]["audio"]["array"]
+ else:
+     input_audio = resampler(torch.from_numpy(dataset[0]["audio"]["array"]))
+
+ inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
+ with torch.no_grad():
+     outputs = model(**inputs, output_hidden_states=True)
+
+ # take a look at the output shape; there are 13 layers of representations
+ # each layer performs differently on different downstream tasks, so choose empirically
+ all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
+ print(all_layer_hidden_states.shape) # [13 layers, time steps, 768 feature_dim]
+
+ # for utterance-level classification tasks, you can simply average the representation over time
+ time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
+ print(time_reduced_hidden_states.shape) # [13, 768]
+
+ # you can even use a learnable weighted average representation
+ aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
+ weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
+ print(weighted_avg_hidden_states.shape) # [768]
+ ```
 
 # Citation
 ```shell
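
For downstream use, the time-reduced `[13, 768]` representation from the demo can be fed to a small probe that learns the layer weighting jointly with a classifier. The sketch below is only an illustration of that idea, not part of the committed README or the model repo: `MusicTagger`, `num_classes=10`, and the random tensors standing in for MERT features and labels are all hypothetical; only the 13-layer, 768-dim feature shape is taken from the demo output.

```python
import torch
from torch import nn

# Hypothetical downstream probe: learns a weighted average over the 13 MERT layers
# and maps the 768-dim pooled feature to class logits. Shapes follow the demo
# output above; the class count and data below are made up for illustration.
class MusicTagger(nn.Module):
    def __init__(self, num_layers=13, hidden_dim=768, num_classes=10):
        super().__init__()
        self.aggregator = nn.Conv1d(in_channels=num_layers, out_channels=1, kernel_size=1)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, layer_features):             # [batch, 13, 768]
        pooled = self.aggregator(layer_features)   # [batch, 1, 768]
        return self.classifier(pooled.squeeze(1))  # [batch, num_classes]

# stand-in batch: pretend these are time-reduced hidden states for 4 clips
features = torch.randn(4, 13, 768)
labels = torch.randint(0, 10, (4,))

probe = MusicTagger()
logits = probe(features)
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()  # gradients flow into both the layer aggregator and the classifier
print(logits.shape, loss.item())
```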