File size: 3,219 Bytes
9e34a62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
from abc import ABC, abstractmethod
from typing import Literal, Optional

from df.enhance import enhance, init_df, load_audio, save_audio
from pydub import AudioSegment


def convert_to_wav(input_file: str, output_file: str):
    """Convert an audio file to WAV format.

    The source container format is taken from the extension of
    ``input_file`` (case-insensitive). When the file has no extension, no
    format is passed and pydub/ffmpeg is left to detect it from the content.

    Args:
        input_file (str): path to input audio file
        output_file (str): path to output WAV file
    """
    # splitext only inspects the final path component, so a dot inside a
    # directory name (e.g. "/data/v1.2/recording") is not taken as an extension.
    extension = os.path.splitext(input_file)[1]
    audio_format = extension[1:].lower() or None

    audio = AudioSegment.from_file(input_file, format=audio_format)

    # export() returns the open output file handle; close it so the WAV is
    # flushed and the descriptor is released deterministically.
    audio.export(output_file, format="wav").close()


def make_output_file_path(
    audio_file: str,
    tag: str,
    ext: Optional[str] = None,
    *,
    directory: str = "./enhanced",
) -> str:
    """Generate the output file path

    The output name is the input's base name (its directory is dropped) with
    ``tag`` inserted before the extension, placed inside ``directory``.
    For example, ``make_output_file_path("in/song.mp3", "_df")`` gives
    ``os.path.join("./enhanced", "song_df.mp3")``.

    Args:
        audio_file (str): path to input audio file
        tag (str): tag to append to the output file name
        ext (str, optional): extension of the output file. It is appended
            verbatim, so include the leading dot (e.g. ".wav"). When None or
            empty, the input file's extension is kept. Defaults to None.
        directory (str, optional): directory the output file is placed in.
            Defaults to "./enhanced".

    Returns:
        str: path to output file
    """
    stem, input_extension = os.path.splitext(os.path.basename(audio_file))

    # `or` rather than `is None`: an empty-string ext also falls back to the
    # input file's own extension.
    extension = ext or input_extension

    return os.path.join(directory, stem + tag + extension)


class BaseEnhancer(ABC):
    """Base class for audio enhancers.

    Subclasses must implement ``__call__``, which enhances one audio file and
    returns the path of the file it wrote. ``get_output_file`` is a helper
    that builds a default output path and creates its parent directory.
    """

    def __init__(self, *args, **kwargs):
        # Accept and ignore any arguments so subclasses can call
        # super().__init__(...) cooperatively.
        pass

    @abstractmethod
    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
        """Enhance ``audio_file`` and return the path of the enhanced file.

        Args:
            audio_file (str): path to input audio file
            output_file (str, optional): path to write the result to; when
                None, the implementation may choose a default path.

        Returns:
            str: path to the enhanced audio file
        """
        raise NotImplementedError

    def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str:
        """Build the default output path for ``audio_file`` and make sure its directory exists.

        Args:
            audio_file (str): path to input audio file
            tag (str): tag appended to the output file name
            ext (str, optional): extension of the output file; the input's
                extension is kept when None.

        Returns:
            str: path to output file
        """
        output_file = make_output_file_path(audio_file, tag, ext=ext)
        directory = os.path.dirname(output_file)
        # os.makedirs("") raises, so only create a directory when the path has one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        return output_file


class DFEnhancer(BaseEnhancer):
    """Audio enhancer backed by DeepFilterNet's ``df.enhance`` API.

    ``init_df()`` is called once when the enhancer is constructed. The
    resulting model and state are then reused for every file passed to
    ``__call__``.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted but ignored. Only the first two
        # values returned by init_df() are kept; the third is discarded.
        model, df_state, _ = init_df()
        self.model = model
        self.df_state = df_state

    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
        """Enhance ``audio_file`` and save the result.

        Args:
            audio_file (str): path to input audio file
            output_file (str, optional): destination path. When omitted, a
                path tagged with "_df" is derived via ``get_output_file``.

        Returns:
            str: path the enhanced audio was written to
        """
        if not output_file:
            output_file = self.get_output_file(audio_file, "_df")

        # Load at the model's sample rate and save at that same rate.
        noisy, _meta = load_audio(audio_file, sr=self.df_state.sr())
        cleaned = enhance(self.model, self.df_state, noisy)
        save_audio(output_file, cleaned, self.df_state.sr())

        return output_file


def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
    """Construct the audio enhancer registered under ``enhancer_name``.

    Args:
        enhancer_name (Literal["df"]): name of the audio enhancer; only "df"
            is currently recognised.

    Raises:
        ValueError: if ``enhancer_name`` is not a known enhancer

    Returns:
        BaseEnhancer: a newly constructed audio enhancer
    """
    if enhancer_name != "df":
        raise ValueError(f"Unknown enhancer name: {enhancer_name}")

    import warnings

    # Silence one specific deprecation warning about the "sinc_interpolation"
    # resampling method name before building the DeepFilterNet enhancer.
    # Note: this filter is installed process-wide, not just for this call.
    warnings.filterwarnings(
        "ignore",
        message='"sinc_interpolation" resampling method name is being deprecated and replaced by "sinc_interp_hann" in the next release. The default behavior remains unchanged.',
    )
    return DFEnhancer()