#!/usr/bin/python3
import os

import numpy as np
import scipy.io.wavfile
import torch
import gradio as gr

from audiosr import super_resolution, build_model
from audiotools import AudioSignal
from pyharp import ModelCard, build_endpoint

# HARP displays this metadata in the plugin UI
card = ModelCard(
    name='Versatile Audio Super Resolution',
    description='Upsample audio to 48 kHz and reconstruct the upper spectrum.',
    author='Team Audio',
    tags=['AudioSR', 'Diffusion', 'Super Resolution', 'Upsampling', 'Sample Rate Conversion']
)

os.environ["TOKENIZERS_PARALLELISM"] = "true"
torch.set_float32_matmul_precision("high")
latent_t_per_second = 12.8  # AudioSR's default latent frame rate (latent frames per second of audio)

# load the AudioSR checkpoint once at startup; "basic" is the general-purpose
# model and device="auto" selects CUDA when available
audiosr = build_model(model_name="basic", device="auto")

def process_fn(input_audio_path, seed, guidance_scale, num_inference_steps):
    """
    This function defines the audio processing steps

    Args:
        input_audio_path (str): the audio filepath to be processed.

        <YOUR_KWARGS>: additional keyword arguments necessary for processing.
            NOTE: These should correspond to and match order of UI elements defined below.

    Returns:
        output_audio_path (str): the filepath of the processed audio.
    """

    sig = AudioSignal(input_audio_path, sample_rate=48000)

    outfile = "./output.wav"

    audio_concat = None

    # process the file in 10-second segments; the last segment covers
    # whatever remains, so short files (< 10 s) still get one segment
    seg_length = 10
    total_length = sig.duration
    num_segs = int(np.ceil(total_length / seg_length))

    for audio_segment in range(num_segs):
        start = audio_segment * seg_length
        end = min(start + seg_length, total_length)

        # AudioSignal slicing indexes the batch dimension, so slice the raw
        # waveform tensor instead (shape: batch, channels, samples)
        start_sample = int(start * sig.sample_rate)
        end_sample = int(end * sig.sample_rate)
        sig_seg = AudioSignal(sig.audio_data[..., start_sample:end_sample],
                              sample_rate=sig.sample_rate)
        sig_seg.write("temp.wav")
        audio = super_resolution(
            audiosr,
            "temp.wav",
            seed=seed,
            guidance_scale=guidance_scale,
            ddim_steps=num_inference_steps,
            latent_t_per_second=latent_t_per_second
        )
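        # `audio` comes back as a numpy array of upsampled samples; its shape
        # may include batch/channel dims, so we concatenate on the sample axis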

        if audio_concat is None:
            audio_concat = audio
        else:
            # stitch the upsampled segments back together along the sample axis
            audio_concat = np.concatenate((audio_concat, audio), axis=-1)

    # AudioSR renders at 48 kHz regardless of the input rate; squeeze any
    # batch/channel dims so wavfile gets a 1-D (mono) sample array
    scipy.io.wavfile.write(outfile, rate=48000, data=audio_concat.squeeze())
    return outfile

# Build the Gradio UI
with gr.Blocks() as webapp:
    # input widgets, in the same order as process_fn's keyword arguments
    inputs = [
        gr.Audio(
            label="Audio Input", 
            type="filepath"
        ), 
        gr.Slider(
            label="Seed",
            minimum=0,
            maximum=65535,
            value=0,
            step=1
        ),
        gr.Slider(
            minimum=0, maximum=10, 
            value=3.5, 
            label="Guidance Scale"
        ),
        gr.Slider(
            minimum=1, maximum=500, 
            step=1, value=50, 
            label="Inference Steps"
        ),
    ]

    # make an output audio widget
    output = gr.Audio(label="Audio Output", type="filepath")

    # Wire the widgets to process_fn; pyharp adds the process/cancel buttons
    ctrls_data, ctrls_button, process_button, cancel_button = build_endpoint(inputs, output, process_fn, card)

webapp.queue()  # queueing is required for cancellation and long diffusion jobs
webapp.launch(share=True)  # share=True also serves a temporary public gradio.live URL
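
# A minimal local smoke test (hypothetical filename; bypasses HARP and the UI):
#
#   print(process_fn("input_example.wav", seed=0,
#                    guidance_scale=3.5, num_inference_steps=50))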