Spaces:
Sleeping
Sleeping
import torchaudio | |
from torch import mean as _mean | |
from torch import hamming_window, log10, no_grad, exp | |
def return_input(user_input):
    """Return *user_input* unchanged.

    ``None`` is passed straight through like any other value.
    """
    # The old explicit ``if user_input is None: return None`` branch returned
    # exactly the value it was given, so both paths reduce to an identity.
    return user_input
def stereo_to_mono_convertion(waveform):
    """Average a multi-channel waveform down to a single channel.

    Dim 0 is treated as the channel axis. When it has more than one entry,
    the entries are averaged and the axis is kept with size 1. Otherwise the
    input tensor is returned as-is (same object, no copy).
    """
    channel_count = waveform.shape[0]
    if channel_count <= 1:
        return waveform
    # keepdim=True preserves the leading channel axis, so the result stays 2-D.
    return _mean(waveform, dim=0, keepdim=True)
def load_audio(audio_path):
    """Read an audio file and return it as a mono waveform at 16 kHz.

    The file is decoded with ``torchaudio.load``, down-mixed with
    ``stereo_to_mono_convertion``, then resampled from its native rate to
    16000 Hz.
    """
    waveform, source_rate = torchaudio.load(audio_path)
    mono = stereo_to_mono_convertion(waveform)
    return torchaudio.functional.resample(mono, source_rate, 16000)
def load_audio_numpy(audio_path):
    """Read an audio file and return ``(16000, samples)`` with a mono 1-D NumPy array.

    The waveform is down-mixed to one channel and resampled to 16000 Hz, the
    same preprocessing as ``load_audio``.

    Returns:
        A ``(sample_rate, samples)`` tuple. ``sample_rate`` is always 16000 and
        ``samples`` is a flat NumPy array.
    """
    audio_tensor, sr = torchaudio.load(audio_path)
    # Down-mix before flattening. ``ravel()`` flattens (channels, time)
    # row-major, so a multi-channel file would otherwise come out as its
    # channels concatenated end to end (C times longer than the recording).
    audio_tensor = stereo_to_mono_convertion(audio_tensor)
    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
    audio_array = audio_tensor.numpy()
    return (16000, audio_array.ravel())
def audio_to_spectrogram(audio):
    """Return the complex short-time Fourier transform of *audio*.

    Uses n_fft=512, hop_length=128 (n_fft // 4) and a Hamming window.
    ``power=None`` asks torchaudio for the complex spectrum rather than a
    power spectrum, so phase is still available for reconstruction.
    """
    n_fft = 512
    stft = torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        hop_length=n_fft // 4,
        power=None,
        window_fn=hamming_window,
    )
    return stft(audio)
def extract_magnitude_and_phase(spectrogram):
    """Split a complex spectrogram into ``(magnitude, phase)``.

    Magnitude is ``abs()`` of each bin; phase is ``angle()`` (radians).
    """
    magnitude = spectrogram.abs()
    phase = spectrogram.angle()
    return magnitude, phase
def amplitude_to_db(magnitude_spec):
    """Convert a magnitude spectrogram to decibels relative to its own peak.

    Each bin becomes ``20 * log10(max(bin, amin)) - 20 * log10(peak)``.
    ``amplitude_to_DB`` then floors the result at 100 dB below its maximum
    (``top_db=100.0``).

    Returns:
        ``(db_spectrogram, max_amplitude)``. ``max_amplitude`` is the reference
        peak used above; ``db_to_amplitude`` needs it to undo the conversion.
    """
    # 10e-10 == 1e-9. The value is kept as-is so the dB scale stays identical
    # to what the rest of the pipeline (and the fitted scaler) expects.
    amin = 10e-10
    # An all-zero (silent) input has max 0 and log10(0) == -inf, which would
    # turn the whole dB spectrogram into +inf and the reconstructed audio into
    # NaN (0 * 10**inf). Clamping the reference to amin is a no-op for any
    # input whose peak is >= amin.
    max_amplitude = magnitude_spec.max().clamp(min=amin)
    db_spectrogram = torchaudio.functional.amplitude_to_DB(
        magnitude_spec, 20, amin, log10(max_amplitude), 100.0
    )
    return db_spectrogram, max_amplitude
def min_max_scaling(spectrogram, scaler):
    """Apply the fitted *scaler*'s forward ``transform`` to *spectrogram*.

    NOTE: this min-max step is of questionable mathematical soundness, because
    each spectrogram was already normalised by its own peak during the
    decibel conversion.
    """
    return scaler.transform(spectrogram)
def inverse_min_max(spectrogram, scaler):
    """Undo ``min_max_scaling`` via the *scaler*'s ``inverse_transform``."""
    return scaler.inverse_transform(spectrogram)
def db_to_amplitude(db_spectrogram, max_amplitude):
    """Map peak-relative decibels back to linear amplitude.

    This is the inverse of the ``20 * log10(x / peak)`` conversion:
    ``peak * 10 ** (dB / 20)``.
    """
    gain = 10 ** (db_spectrogram / 20)
    return max_amplitude * gain
def reconstruct_complex_spectrogram(magnitude, phase):
    """Rebuild a complex spectrogram as ``magnitude * e^(i * phase)``."""
    unit_phasor = exp(1j * phase)
    return magnitude * unit_phasor
def inverse_fft(spectrogram, length=None):
    """Convert a complex spectrogram back to a waveform (an inverse STFT).

    The transform parameters (n_fft=512, hop_length=128, Hamming window)
    mirror the forward transform in ``audio_to_spectrogram``.

    Args:
        spectrogram: complex spectrogram to invert.
        length: optional number of output samples. With the default ``None``
            the length is derived from the frame count, which can be up to
            ``hop_length - 1`` samples shorter than the signal that produced
            the spectrogram. Pass the original waveform length to get an
            exact-length result.

    Returns:
        The reconstructed waveform tensor.
    """
    inverse_fn = torchaudio.transforms.InverseSpectrogram(
        n_fft=512, hop_length=512 // 4, window_fn=hamming_window
    )
    return inverse_fn(spectrogram, length)
def transform_audio(audio, scaler):
    """Turn a waveform into the scaled dB spectrogram fed to the model.

    Pipeline: complex STFT, then magnitude/phase split, then peak-relative dB,
    then the scaler's min-max transform.

    Returns:
        ``(model_input, phase, max_amplitude)``. ``phase`` and
        ``max_amplitude`` are what ``spectrogram_to_audio`` needs to invert
        the pipeline.
    """
    magnitude, phase = extract_magnitude_and_phase(audio_to_spectrogram(audio))
    db_spectrogram, max_amplitude = amplitude_to_db(magnitude)
    scaled = min_max_scaling(db_spectrogram, scaler)
    # Prepend one leading axis. Presumably this is the model's batch
    # dimension; confirm against the model definition.
    return scaled.unsqueeze(0), phase, max_amplitude
def spectrogram_to_audio(db_spectrogram, scaler, phase, max_amplitude):
    """Invert ``transform_audio``: a scaled dB spectrogram becomes a waveform.

    Steps: drop the leading axis added by ``transform_audio``, undo min-max
    scaling, convert dB back to linear magnitude, re-attach the original
    phase, and run the inverse STFT.
    """
    unscaled_db = inverse_min_max(db_spectrogram.squeeze(0), scaler)
    magnitude = db_to_amplitude(unscaled_db, max_amplitude)
    return inverse_fft(reconstruct_complex_spectrogram(magnitude, phase))
def save_audio(audio, path="enhanced_audio.wav", sample_rate=16000):
    """Write *audio* to a WAV file and return the path it was written to.

    Args:
        audio: waveform tensor in the layout ``torchaudio.save`` expects.
        path: destination file. The default is the original fixed,
            CWD-relative name, so each call with the default overwrites the
            previous output. If calls can run concurrently, pass a unique path.
        sample_rate: sample rate recorded in the file. Defaults to 16000,
            the rate every loader in this module resamples to.

    Returns:
        *path*, unchanged.
    """
    torchaudio.save(path, audio, sample_rate)
    return path
def predict(user_input, model, scaler):
    """Run the enhancement model on an audio file and return the output WAV path.

    Args:
        user_input: path of the input audio file (passed to ``load_audio``).
        model: object whose ``forward`` maps a scaled dB spectrogram to an
            enhanced one.
        scaler: fitted scaler providing ``transform`` / ``inverse_transform``.

    Returns:
        The path returned by ``save_audio``.
    """
    audio = load_audio(user_input)
    model_input, phase, max_amplitude = transform_audio(audio, scaler)
    # Inference only; skip building an autograd graph.
    with no_grad():
        # NOTE(review): calling ``forward`` directly bypasses nn.Module hooks;
        # ``model(model_input)`` is the usual idiom if ``model`` is an nn.Module.
        model_output = model.forward(model_input)
    enhanced = spectrogram_to_audio(model_output, scaler, phase, max_amplitude)
    return save_audio(enhanced)