Spaces:
Sleeping
Sleeping
import os | |
from abc import ABC | |
from typing import Literal, Optional | |
from df.enhance import enhance, init_df, load_audio, save_audio | |
from pydub import AudioSegment | |
def convert_to_wav(input_file: str, output_file: str):
    """Convert an audio file to WAV format.

    The input format is taken from the file's extension (case-insensitive).
    When the file has no extension, no format is passed to pydub, which lets
    pydub/ffmpeg work out the format on its own instead of being handed a
    bogus format string.

    Args:
        input_file (str): path to input audio file
        output_file (str): path to output WAV file
    """
    # splitext only inspects the final path component, so a dot inside a
    # directory name (e.g. "data.v2/clip") is not mistaken for an extension.
    # The previous `split(".")[-1]` returned "v2/clip" as the "format" there.
    extension = os.path.splitext(input_file)[1]
    audio_format = extension[1:].lower() or None
    audio = AudioSegment.from_file(input_file, format=audio_format)
    audio.export(output_file, format="wav")
def make_output_file_path(
    audio_file: str,
    tag: str,
    ext: Optional[str] = None,
    directory: str = "./enhanced",
) -> str:
    """Generate the output file path for a processed copy of an audio file.

    The output name is the input's base name (directory and extension
    stripped), followed by ``tag`` and an extension, placed in ``directory``.
    For example, ``("/data/song.mp3", "_df")`` gives ``"./enhanced/song_df.mp3"``.

    Args:
        audio_file (str): path to input audio file
        tag (str): tag appended to the output file name (before the extension)
        ext (str, optional): extension of the output file, with or without the
            leading dot (``".wav"`` or ``"wav"``). Defaults to None, which
            reuses the input file's extension.
        directory (str, optional): directory for the output file. Defaults to
            ``"./enhanced"``. The directory is not created here.

    Returns:
        str: path to output file
    """
    stem, input_extension = os.path.splitext(os.path.basename(audio_file))
    extension = ext or input_extension
    # Accept "wav" as well as ".wav". Without the dot, the extension would be
    # glued directly onto the tag ("song_dfwav").
    if extension and not extension.startswith("."):
        extension = "." + extension
    return os.path.join(directory, stem + tag + extension)
class BaseEnhancer(ABC):
    """Common interface for audio enhancers.

    Both ``__init__`` and ``__call__`` raise ``NotImplementedError`` here, so
    this base class cannot be constructed directly. Concrete enhancers must
    override both. ``get_output_file`` is a shared helper for choosing a
    default output location.
    """

    def __init__(self, *args, **kwargs):
        # Subclasses must supply their own constructor and must not chain to
        # this one, because it always raises.
        raise NotImplementedError

    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
        """Process ``audio_file`` and return the path of the written output."""
        raise NotImplementedError

    def get_output_file(
        self, audio_file: str, tag: str, ext: Optional[str] = None
    ) -> str:
        """Return a default output path for ``audio_file``, creating its parent directory.

        Args:
            audio_file (str): path to input audio file
            tag (str): tag appended to the output file name
            ext (str, optional): extension of the output file. Defaults to
                None (keep the input's extension).

        Returns:
            str: path to output file
        """
        path = make_output_file_path(audio_file, tag, ext=ext)
        parent_dir = os.path.dirname(path)
        os.makedirs(parent_dir, exist_ok=True)
        return path
class DFEnhancer(BaseEnhancer):
    """Audio enhancer backed by DeepFilterNet (the ``df.enhance`` API).

    The model and its state are created once by ``init_df()`` at construction
    and are reused for every call.
    """

    def __init__(self, *args, **kwargs):
        # Extra arguments are accepted for interface compatibility but ignored.
        # BaseEnhancer.__init__ is intentionally not called: it always raises.
        self.model, self.df_state, _ = init_df()

    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
        """Enhance ``audio_file`` with DeepFilterNet and save the result.

        Args:
            audio_file (str): path to the input audio file
            output_file (str, optional): where to save the enhanced audio.
                Defaults to None (or any falsy value), which uses a path with
                the ``"_df"`` tag from ``get_output_file``.

        Returns:
            str: path to the enhanced audio file
        """
        if not output_file:
            output_file = self.get_output_file(audio_file, "_df")
        sample_rate = self.df_state.sr()
        # Load at the model's sample rate; the second returned item is unused.
        noisy, _ = load_audio(audio_file, sr=sample_rate)
        cleaned = enhance(self.model, self.df_state, noisy)
        save_audio(output_file, cleaned, sample_rate)
        return output_file
def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
    """Construct the audio enhancer registered under ``enhancer_name``.

    Args:
        enhancer_name (Literal["df"]): name of the audio enhancer; only "df"
            is currently supported

    Raises:
        ValueError: if the enhancer name is not recognised

    Returns:
        BaseEnhancer: a newly constructed audio enhancer
    """
    if enhancer_name != "df":
        raise ValueError(f"Unknown enhancer name: {enhancer_name}")

    import warnings

    # Suppress one known deprecation notice about a resampling-method rename
    # (presumably emitted by torchaudio while DeepFilterNet resamples; verify).
    # The filter goes into the process-wide warnings filter list.
    warnings.filterwarnings(
        "ignore",
        message='"sinc_interpolation" resampling method name is being deprecated and replaced by "sinc_interp_hann" in the next release. The default behavior remains unchanged.',
    )
    return DFEnhancer()