Spaces:
Sleeping
Sleeping
File size: 3,219 Bytes
9e34a62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import os
from abc import ABC
from typing import Literal, Optional
from df.enhance import enhance, init_df, load_audio, save_audio
from pydub import AudioSegment
def convert_to_wav(input_file: str, output_file: str):
    """Convert an audio file to WAV format

    Args:
        input_file (str): path to input audio file
        output_file (str): path to output WAV file
    """
    # Derive the format from the extension of the final path component only.
    # Splitting the whole path on "." breaks when a directory name contains a
    # dot and the file itself has no extension (e.g. "./dir.v2/audio").
    extension = os.path.splitext(input_file)[1]
    # With no extension, pass None so the decoder probes the container itself
    # instead of being handed a bogus format name.
    source_format = extension[1:].lower() or None
    # Read the audio file
    audio = AudioSegment.from_file(input_file, format=source_format)
    # Export as WAV
    audio.export(output_file, format="wav")
def make_output_file_path(audio_file: str, tag: str, ext: Optional[str] = None) -> str:
    """Build the path under ``./enhanced`` where a processed file is written

    Args:
        audio_file (str): path to input audio file; only its base name is used
        tag (str): text inserted between the file stem and the extension
        ext (str, optional): extension for the output file, used verbatim.
            When omitted or empty, the input file's extension is reused.

    Returns:
        str: path to output file
    """
    # Split the base name once into its stem and original suffix.
    stem, source_suffix = os.path.splitext(os.path.basename(audio_file))
    # An explicit, non-empty extension overrides the one from the input.
    suffix = ext if ext else source_suffix
    return os.path.join("./enhanced", "".join((stem, tag, suffix)))
class BaseEnhancer(ABC):
    """Common interface for audio enhancers.

    Both the constructor and the call operator raise ``NotImplementedError``
    here, so a concrete enhancer has to override each of them.
    """

    def __init__(self, *args, **kwargs):
        # The base class cannot be instantiated directly.
        raise NotImplementedError

    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
        # Subclasses supply the actual enhancement.
        raise NotImplementedError

    def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str:
        """Return the output path for ``audio_file``, creating its directory

        Args:
            audio_file (str): path to input audio file
            tag (str): tag to append to the output file name
            ext (str, optional): extension of the output file. Defaults to None.

        Returns:
            str: path to output file
        """
        destination = make_output_file_path(audio_file, tag, ext=ext)
        parent_dir = os.path.dirname(destination)
        # Make sure the target directory exists before handing the path back.
        os.makedirs(parent_dir, exist_ok=True)
        return destination
class DFEnhancer(BaseEnhancer):
    """Enhancer backed by the model returned from ``df.enhance.init_df``"""

    def __init__(self, *args, **kwargs):
        # init_df() yields three values; only the first two are kept.
        model, df_state, _ = init_df()
        self.model = model
        self.df_state = df_state

    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
        """Enhance ``audio_file`` and write the result to disk

        Args:
            audio_file (str): path to input audio file
            output_file (str, optional): where to write the enhanced audio.
                When omitted, a path tagged with ``_df`` is generated.

        Returns:
            str: path to the written output file
        """
        if not output_file:
            output_file = self.get_output_file(audio_file, "_df")
        # Load the input at the sample rate reported by the model state.
        waveform, _ = load_audio(audio_file, sr=self.df_state.sr())
        cleaned = enhance(self.model, self.df_state, waveform)
        save_audio(output_file, cleaned, self.df_state.sr())
        return output_file
def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
    """Construct the audio enhancer registered under ``enhancer_name``

    Args:
        enhancer_name (Literal["df"]): name of the audio enhancer

    Raises:
        ValueError: if the enhancer name is not recognised

    Returns:
        BaseEnhancer: audio enhancer
    """
    # Guard clause: "df" is the only supported name.
    if enhancer_name != "df":
        raise ValueError(f"Unknown enhancer name: {enhancer_name}")

    import warnings

    # Silence this specific resampling deprecation notice before building
    # the enhancer.
    deprecation_notice = '"sinc_interpolation" resampling method name is being deprecated and replaced by "sinc_interp_hann" in the next release. The default behavior remains unchanged.'
    warnings.filterwarnings("ignore", message=deprecation_notice)
    return DFEnhancer()
|