File size: 4,468 Bytes
d50bd1e
 
 
 
 
8a2882e
d50bd1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a2882e
 
d50bd1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a2882e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d50bd1e
 
 
 
8a2882e
 
 
d50bd1e
8a2882e
d50bd1e
8a2882e
 
d50bd1e
 
 
 
8a2882e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d50bd1e
 
8a2882e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import warnings

warnings.simplefilter("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import io
import torch
import numpy as np
from audiocraft.models import musicgen
from scipy.io.wavfile import write as wav_write

try:
    # Prefer the project's own logger module when it is importable.
    from logger import logging
except ImportError:
    # Only a missing module triggers the fallback; any other error raised while
    # importing ``logger`` is a real bug and must not be masked.
    import logging


class GenerateAudio:
    """Generate music from text prompts with a pretrained MusicGen model.

    After a successful :meth:`generate_audio` call, ``result`` holds one clip
    per prompt and ``sampling_rate`` holds the model's sample rate.
    :meth:`save_audio` and :meth:`get_audio_buffer` serialise those clips as
    WAV data.
    """

    # Class-level defaults so the readiness checks in save_audio() and
    # get_audio_buffer() raise ValueError (not AttributeError) before any audio
    # exists. __init__ also sets them per instance.
    result = None
    sampling_rate = None
    generated_audio = None  # legacy alias of ``result``

    def __init__(self, model="musicgen-stereo-small"):
        """Load ``model`` on CUDA when available, otherwise on CPU.

        Args:
            model: MusicGen checkpoint name. A ``facebook/`` prefix is added
                when missing.

        Raises:
            ValueError: If the pretrained model cannot be loaded.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = self.get_model_name(model)
        self.model = self.get_model(self.model_name, self.device)
        # Populated by generate_audio().
        self.result = None
        self.generated_audio = None
        self.sampling_rate = None

    @staticmethod
    def get_model(model, device):
        """Load a pretrained MusicGen checkpoint.

        Args:
            model: Full checkpoint name, e.g. ``"facebook/musicgen-stereo-small"``.
            device: Torch device string to load the model onto.

        Returns:
            The object returned by ``musicgen.MusicGen.get_pretrained``.

        Raises:
            ValueError: If loading fails. The original exception is chained
                as ``__cause__``.
        """
        try:
            loaded = musicgen.MusicGen.get_pretrained(model, device=device)
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            raise ValueError(f"Failed to load model: {e}") from e
        # Log the checkpoint name, not the repr of the loaded model object.
        logging.info(f"Loaded model: {model}")
        return loaded

    @staticmethod
    def get_model_name(model_name):
        """Return ``model_name`` with a ``facebook/`` prefix, adding it if absent."""
        if model_name.startswith("facebook/"):
            return model_name
        return f"facebook/{model_name}"

    @staticmethod
    def duration_sanity_check(duration):
        """Clamp ``duration`` (seconds) to the range [1, 30], warning when clamped."""
        if duration < 1:
            logging.warning("Duration is less than 1 second. Setting duration to 1 second.")
            return 1
        elif duration > 30:
            logging.warning("Duration is greater than 30 seconds. Setting duration to 30 seconds.")
            return 30
        return duration

    @staticmethod
    def prompts_sanity_check(prompts):
        """Normalise ``prompts`` to a non-empty list of at most 8 strings.

        Args:
            prompts: A single prompt string or a list of prompt strings.

        Returns:
            ``[prompts]`` for a string; otherwise the list itself.

        Raises:
            ValueError: If ``prompts`` is not a string or list of strings, is
                an empty list, or has more than 8 entries.
        """
        if isinstance(prompts, str):
            return [prompts]
        if not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise ValueError("Prompts should be a string or a list of strings.")
        if not prompts:
            # An empty batch can only fail later inside the model; reject it early.
            raise ValueError("At least one prompt is required.")
        if len(prompts) > 8:  # Too many prompts will cause OOM error
            raise ValueError("Maximum number of prompts allowed is 8.")
        return prompts

    def generate_audio(self, prompts, duration=10):
        """Generate one audio clip per prompt.

        Args:
            prompts: A prompt string or a list of at most 8 prompt strings.
            duration: Clip length in seconds, clamped to [1, 30].

        Returns:
            ``(sampling_rate, audio)``. ``audio`` is a NumPy array with one
            entry per prompt, each laid out as ``(samples, channels)``. This
            assumes the model returns ``(batch, channels, samples)``; confirm
            against the installed audiocraft version.

        Raises:
            ValueError: On invalid prompts, or if generation fails (the
                underlying exception is chained).
        """
        duration = self.duration_sanity_check(duration)
        prompts = self.prompts_sanity_check(prompts)

        try:
            self.model.set_generation_params(duration=duration)
            generated = self.model.generate(prompts, progress=False)
            # Swap the last two axes: (batch, channels, samples) ->
            # (batch, samples, channels). scipy's WAV writer expects samples
            # along the first axis of each clip.
            audio = generated.cpu().numpy().transpose(0, 2, 1)
            sampling_rate = self.model.sample_rate
        except Exception as e:
            logging.error(f"Failed to generate audio: {e}")
            raise ValueError(f"Failed to generate audio: {e}") from e

        # Commit state only after every step succeeded, so a failure never
        # leaves new audio paired with a stale sample rate.
        self.result = audio
        self.generated_audio = audio
        self.sampling_rate = sampling_rate
        logging.info(
            f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
        )
        return self.sampling_rate, self.result

    def _require_generated(self):
        """Raise ValueError unless generate_audio() has produced audio and a sample rate."""
        if self.result is None:
            raise ValueError("Audio is not generated yet.")
        if self.sampling_rate is None:
            raise ValueError("Sampling rate is not available.")

    def save_audio(self, audio_dir="generated_audio"):
        """Write each generated clip to ``<audio_dir>/audio_<i>.wav``.

        The directory is created if needed. Existing files with the same names
        are overwritten.

        Returns:
            The written file paths, in prompt order.

        Raises:
            ValueError: If no audio has been generated yet.
        """
        self._require_generated()
        os.makedirs(audio_dir, exist_ok=True)
        paths = []
        for i, audio in enumerate(self.result):
            path = os.path.join(audio_dir, f"audio_{i}.wav")
            wav_write(path, self.sampling_rate, audio)
            paths.append(path)
        return paths

    def get_audio_buffer(self):
        """Return one in-memory WAV file (``io.BytesIO``, rewound to 0) per clip.

        Raises:
            ValueError: If no audio has been generated yet.
        """
        self._require_generated()
        buffers = []
        for audio in self.result:
            buffer = io.BytesIO()
            wav_write(buffer, self.sampling_rate, audio)
            buffer.seek(0)  # ready for the caller to read from the start
            buffers.append(buffer)
        return buffers

if __name__ == "__main__":
    # Demo: generate three 10-second clips, write them to disk, then build
    # in-memory WAV buffers for the same clips.
    demo_prompts = [
        "A piano playing a jazz melody",
        "A guitar playing a rock riff",
        "A LoFi music for coding",
    ]
    generator = GenerateAudio()
    sample_rate, result = generator.generate_audio(demo_prompts, duration=10)

    saved_paths = generator.save_audio()
    print(f"Saved audio to: {saved_paths}")

    audio_buffers = generator.get_audio_buffer()
    print(f"Audio buffers: {audio_buffers}")