import gradio as gr
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio
import torchaudio.transforms as T

# input components adapted from https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py

inputs = [
    gr.components.Audio(type="filepath", label="Add music audio file"),
    gr.components.Audio(source="microphone", type="filepath", label="Or record with microphone"),
]
outputs = [gr.components.Textbox()]
# outputs = [gr.components.Textbox(), transcription_df]
title = "Output the tags of a (music) audio clip"
description = "An example of using MERT-v0-public (a 95M-parameter model) for music tagging."
article = ""
audio_examples = [
    ["input/example-1.wav"],
    ["input/example-2.wav"],
]

# Load the model
model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
# loading the corresponding preprocessor config
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public",trust_remote_code=True)


def convert_audio(inputs, microphone):
    # prefer the microphone recording when one is provided
    if microphone is not None:
        inputs = microphone
    if inputs is None:
        return "Please upload or record an audio clip first."

    waveform, sample_rate = torchaudio.load(inputs)

    # mix multi-channel audio down to a single mono waveform
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0)
    else:
        waveform = waveform.squeeze(0)

    resample_rate = processor.sampling_rate

    # make sure the sample rate matches what the feature extractor expects
    if resample_rate != sample_rate:
        print(f'resampling from {sample_rate} Hz to {resample_rate} Hz')
        resampler = T.Resample(sample_rate, resample_rate)
        waveform = resampler(waveform)

    inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # inspect the output shape: there are 13 layers of representations
    # (the embedding layer plus 12 transformer layers); each layer performs
    # differently on different downstream tasks, so choose the layer(s) empirically
    all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
    # print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
    return str(all_layer_hidden_states.shape)
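
# --- Not part of the original Space: a minimal sketch of how the 13 layers could be
# pooled into a single clip-level feature for a downstream tagging head. The
# time-mean plus learnable Conv1d aggregation follows the usage example on the MERT
# model card; the helper name `pool_layers` is ours, and the aggregator below is
# randomly initialised (in practice it would be trained jointly with the head).
def pool_layers(all_layer_hidden_states: torch.Tensor) -> torch.Tensor:
    # [13, time, 768] -> [13, 768]: average each layer's representation over time
    time_reduced = all_layer_hidden_states.mean(-2)
    # learnable weighted combination of the 13 layers -> [768]
    aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
    return aggregator(time_reduced.unsqueeze(0)).squeeze()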


# iface = gr.Interface(fn=convert_audio, inputs="audio", outputs="text")
# iface.launch()
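
# --- Sketch only (not in the original Space): the demo currently returns the
# representation shape rather than actual tags. A real tagger would apply a small
# trained head to the pooled feature; `predict_tag_logits` and its default
# `num_tags` are hypothetical placeholders, and the head is untrained here, so its
# outputs are meaningless until fine-tuned.
def predict_tag_logits(all_layer_hidden_states: torch.Tensor, num_tags: int = 50) -> torch.Tensor:
    # untrained linear head over the pooled [768] MERT feature -> [num_tags] logits
    tag_head = nn.Linear(768, num_tags)
    return tag_head(pool_layers(all_layer_hidden_states))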

audio_chunked = gr.Interface(
    fn=convert_audio,
    inputs=inputs,
    outputs=outputs,
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    examples=audio_examples,
)


demo = gr.Blocks()
with demo:
    gr.TabbedInterface([audio_chunked], ["Audio File"])
# demo.queue(concurrency_count=1, max_size=5)
demo.launch(show_api=False)