|
import os |
|
import yaml |
|
import logging |
|
import nltk |
|
import torch |
|
import torchaudio |
|
from torchaudio.transforms import SpeedPerturbation |
|
from APIs import WRITE_AUDIO, LOUDNESS_NORM |
|
|
|
from flask import Flask, request, jsonify |
|
import numpy as np |
|
|
|
def fade(audio_data, fade_duration=2, sr=32000):
    """Apply a linear fade-in and fade-out to a waveform.

    Clips of 8 seconds or longer get ``fade_duration``-second ramps at each
    end; shorter clips instead use ramps one fifth of their duration.

    Args:
        audio_data: Array of samples whose first axis is time (assumed 1-D,
            since a 1-D ramp is multiplied against it).
        fade_duration: Length of each ramp in seconds, used for clips >= 8 s.
        sr: Sample rate of ``audio_data`` in Hz.

    Returns:
        A new array with exactly as many samples as ``audio_data``.
    """
    num_samples = audio_data.shape[0]
    audio_duration = num_samples / sr

    if audio_duration < 8:
        # Short clips: scale the ramp with the clip length so both fades fit.
        fade_duration = audio_duration / 5

    # Clamp so the two ramps never overlap and never go negative; this keeps
    # the output length equal to the input length for any fade_duration.
    fade_samples = max(0, min(int(sr * fade_duration), num_samples // 2))

    fade_in = np.linspace(0, 1, fade_samples)
    fade_out = np.linspace(1, 0, fade_samples)

    # Use explicit end indices: with fade_samples == 0, ``audio_data[-0:]``
    # would select the whole array instead of nothing.
    head = audio_data[:fade_samples] * fade_in
    middle = audio_data[fade_samples:num_samples - fade_samples]
    tail = audio_data[num_samples - fade_samples:] * fade_out

    return np.concatenate((head, middle, tail))
|
|
|
def get_service_port():
    """Look up the port configured for the WavJourney services.

    Returns:
        The raw string value of the ``WAVJOURNEY_SERVICE_PORT`` environment
        variable, or ``None`` when the variable is not set.
    """
    return os.environ.get('WAVJOURNEY_SERVICE_PORT')
|
|
|
# Load service settings from config.yaml. The path is relative, so it is
# resolved against the current working directory, not this file's location.
# yaml.safe_load is used so the file cannot construct arbitrary Python objects.
# Sections read later in this module: 'AudioCraft', 'Text-to-Speech',
# 'Voice-Parser'.
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)
|
|
|
|
|
# Root logger: INFO level, emitted to stderr by basicConfig's default handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Also mirror log records to services_logs/Wav-API.log.
# mode='w' truncates the log file every time the service starts.
os.makedirs('services_logs', exist_ok=True)
log_filename = 'services_logs/Wav-API.log'
file_handler = logging.FileHandler(log_filename, mode='w')
# Same format string as the basicConfig handler above.
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

# Attach to the root logger so the module-level logging.info(...) calls that
# follow reach the file as well as stderr.
logging.getLogger('').addHandler(file_handler)
|
|
|
|
|
""" |
|
Initialize the AudioCraft models here |
|
""" |
|
from audiocraft.models import AudioGen, MusicGen

# Text-to-audio (sound-effect) model. The size suffix comes from
# config['AudioCraft']['tta_model_size']. Weights are loaded at import time,
# presumably fetched from the Hugging Face Hub on first run — confirm.
tta_model_size = config['AudioCraft']['tta_model_size']
tta_model = AudioGen.get_pretrained(f'facebook/audiogen-{tta_model_size}')
logging.info(f'AudioGen ({tta_model_size}) is loaded ...')

# Text-to-music model, sized by config['AudioCraft']['ttm_model_size'].
ttm_model_size = config['AudioCraft']['ttm_model_size']
ttm_model = MusicGen.get_pretrained(f'facebook/musicgen-{ttm_model_size}')
logging.info(f'MusicGen ({ttm_model_size}) is loaded ...')
|
|
|
|
|
""" |
|
Initialize the BarkModel here |
|
""" |
|
from transformers import BarkModel, AutoProcessor

# Playback-speed factor for synthesized speech, taken from config.
SPEED = float(config['Text-to-Speech']['speed'])
# Operates at 32000 Hz, the rate speech is resampled to before it is applied.
# Only one factor is supplied, so every call uses SPEED.
speed_perturb = SpeedPerturbation(32000, [SPEED])

tts_model = BarkModel.from_pretrained("suno/bark")
# Use the first GPU when present, otherwise run on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tts_model = tts_model.to(device)
# NOTE(review): to_bettertransformer relies on the optimum integration and is
# deprecated in newer transformers releases — confirm against the pinned version.
tts_model = tts_model.to_bettertransformer()

# Bark's native output rate, as reported by its generation config.
SAMPLE_RATE = tts_model.generation_config.sample_rate

# Sampling temperatures for Bark's semantic / coarse / fine stages; passed to
# tts_model.generate for every sentence.
SEMANTIC_TEMPERATURE = 0.9
COARSE_TEMPERATURE = 0.5
FINE_TEMPERATURE = 0.5

# Tokenizer/processor that also loads voice presets for each request.
processor = AutoProcessor.from_pretrained("suno/bark")
logging.info('Bark model is loaded ...')
|
|
|
|
|
""" |
|
Initialize the VoiceFixer model here |
|
""" |
|
from voicefixer import VoiceFixer

# Speech-restoration model, shared by all /fix_audio requests.
vf = VoiceFixer()
logging.info('VoiceFixer is loaded ...')
|
|
|
|
|
""" |
|
Initialize the VoiceParser model here
|
""" |
|
from VoiceParser.model import VoiceParser

# Acoustic-embedding extractor used by /parse_voice; its device string comes
# straight from config['Voice-Parser']['device'].
vp_device = config['Voice-Parser']['device']
vp = VoiceParser(device=vp_device)
logging.info('VoiceParser is loaded ...')
|
|
|
|
|
# Flask application exposing the generation / restoration endpoints below.
app = Flask(__name__)
|
|
|
|
|
@app.route('/generate_audio', methods=['POST'])
def generate_audio():
    """Generate a sound effect from a text prompt with AudioGen.

    Expects a JSON body with:
        text (str, required): prompt passed to AudioGen.
        length (float, optional): clip duration in seconds, default 5.0.
        volume (float, optional): target loudness in dB for LOUDNESS_NORM,
            default -35.
        output_wav (str, optional): path handed to WRITE_AUDIO,
            default 'out.wav'.

    Returns:
        200 with JSON ``{'message', 'file'}`` on success, otherwise 500 with
        JSON ``{'API error': <message>}``.
    """
    try:
        # Parse inside the try so a missing/invalid field produces the same
        # JSON error payload as a generation failure, not Flask's HTML page.
        data = request.json
        text = data['text']
        length = float(data.get('length', 5.0))
        volume = float(data.get('volume', -35))
        output_wav = data.get('output_wav', 'out.wav')

        logging.info(f'TTA (AudioGen): Prompt: {text}, length: {length} seconds, volume: {volume} dB')

        tta_model.set_generation_params(duration=length)
        wav = tta_model.generate([text])
        # Resample from the model's own rate instead of assuming 16 kHz.
        wav = torchaudio.functional.resample(wav, orig_freq=tta_model.sample_rate, new_freq=32000)

        wav = wav.squeeze().cpu().detach().numpy()
        wav = fade(LOUDNESS_NORM(wav, volumn=volume))
        WRITE_AUDIO(wav, name=output_wav)

        return jsonify({'message': f'Text-to-Audio generated successfully | {text}', 'file': output_wav})

    except Exception as e:
        # Record the traceback; previously failures left no trace in the log.
        logging.exception('TTA (AudioGen) request failed')
        return jsonify({'API error': str(e)}), 500
|
|
|
|
|
@app.route('/generate_music', methods=['POST'])
def generate_music():
    """Generate a music clip from a text prompt with MusicGen.

    Expects a JSON body with:
        text (str, required): prompt passed to MusicGen.
        length (float, optional): clip duration in seconds, default 5.0.
        volume (float, optional): target loudness in dB for LOUDNESS_NORM,
            default -35.
        output_wav (str, optional): path handed to WRITE_AUDIO,
            default 'out.wav'.

    Returns:
        200 with JSON ``{'message', 'file'}`` on success, otherwise 500 with
        JSON ``{'API error': <message>}``.
    """
    try:
        # Parse inside the try so bad requests get the JSON error payload.
        data = request.json
        text = data['text']
        length = float(data.get('length', 5.0))
        volume = float(data.get('volume', -35))
        output_wav = data.get('output_wav', 'out.wav')

        logging.info(f'TTM (MusicGen): Prompt: {text}, length: {length} seconds, volume: {volume} dB')

        ttm_model.set_generation_params(duration=length)
        wav = ttm_model.generate([text])
        # Bring the output to the 32 kHz rate the other endpoints produce;
        # resample returns the input unchanged when the rates already match.
        wav = torchaudio.functional.resample(wav, orig_freq=ttm_model.sample_rate, new_freq=32000)
        # First item of the batch, first channel.
        wav = wav[0][0].cpu().detach().numpy()
        wav = fade(LOUDNESS_NORM(wav, volumn=volume))
        WRITE_AUDIO(wav, name=output_wav)

        return jsonify({'message': f'Text-to-Music generated successfully | {text}', 'file': output_wav})

    except Exception as e:
        logging.exception('TTM (MusicGen) request failed')
        return jsonify({'API error': str(e)}), 500
|
|
|
|
|
@app.route('/generate_speech', methods=['POST'])
def generate_speech():
    """Synthesize speech with Bark, one sentence at a time.

    Expects a JSON body with:
        text (str, required): text to speak; newlines are treated as spaces.
        speaker_id (required): speaker label, used only in the log line and
            the response message.
        speaker_npz (required): voice preset handed to the Bark processor.
        volume (float, optional): target loudness in dB, default -35.
        output_wav (str, optional): output path, default 'out.wav'.

    The text is split into sentences with nltk. Each sentence is generated
    separately and followed by 0.1 s of silence. The concatenation is then
    resampled to 32 kHz, speed-adjusted, loudness-normalized and written.

    Returns:
        200 with JSON ``{'message', 'file'}`` on success, otherwise 500 with
        JSON ``{'API error': <message>}``.
    """
    try:
        # Parse inside the try so bad requests get the JSON error payload.
        data = request.json
        text = data['text']
        speaker_id = data['speaker_id']
        speaker_npz = data['speaker_npz']
        volume = float(data.get('volume', -35))
        output_wav = data.get('output_wav', 'out.wav')

        logging.info(f'TTS (Bark): Speaker: {speaker_id}, Volume: {volume} dB, Prompt: {text}')

        text = text.replace('\n', ' ').strip()
        sentences = nltk.sent_tokenize(text)
        if not sentences:
            # Fail with a clear message instead of torch.cat's empty-list error.
            raise ValueError('No text to synthesize')

        # 0.1 s gap appended after every sentence, shaped (1, samples) so it
        # concatenates with the generated audio along dim 1.
        silence = torch.zeros(int(0.1 * SAMPLE_RATE), device=device).unsqueeze(0)

        pieces = []
        for sentence in sentences:
            inputs = processor(sentence, voice_preset=speaker_npz).to(device)

            # NOTE(review): the transpose/contiguous/transpose round trip keeps
            # the values but changes the tensor's memory layout; presumably a
            # workaround for a Bark/BetterTransformer layout issue — confirm
            # before removing.
            inputs['history_prompt']['coarse_prompt'] = inputs['history_prompt']['coarse_prompt'].transpose(0, 1).contiguous().transpose(0, 1)

            with torch.inference_mode():
                output = tts_model.generate(
                    **inputs,
                    do_sample=True,
                    semantic_temperature=SEMANTIC_TEMPERATURE,
                    coarse_temperature=COARSE_TEMPERATURE,
                    fine_temperature=FINE_TEMPERATURE,
                )

            pieces += [output, silence]

        result_audio = torch.cat(pieces, dim=1)
        wav_tensor = result_audio.to(dtype=torch.float32).cpu()
        wav = torchaudio.functional.resample(wav_tensor, orig_freq=SAMPLE_RATE, new_freq=32000)
        # SpeedPerturbation returns (waveform, lengths): keep the waveform and
        # drop the leading batch dimension.
        wav = speed_perturb(wav.float())[0].squeeze(0)
        wav = wav.numpy()
        wav = LOUDNESS_NORM(wav, volumn=volume)
        WRITE_AUDIO(wav, name=output_wav)

        return jsonify({'message': f'Text-to-Speech generated successfully | {speaker_id}: {text}', 'file': output_wav})

    except Exception as e:
        logging.exception('TTS (Bark) request failed')
        return jsonify({'API error': str(e)}), 500
|
|
|
|
|
@app.route('/fix_audio', methods=['POST'])
def fix_audio():
    """Restore a speech file in place with VoiceFixer.

    Expects a JSON body with:
        processfile (str, required): path that is read and then overwritten
            with the restored audio.

    Returns:
        200 with JSON ``{'message', 'file'}`` on success, otherwise 500 with
        JSON ``{'API error': <message>}``.
    """
    try:
        # Parse inside the try so bad requests get the JSON error payload.
        data = request.json
        processfile = data['processfile']

        logging.info(f'Fixing {processfile} ...')

        # Use the GPU only when one exists, matching the CPU fallback used for
        # Bark. A hard-coded cuda=True fails on CPU-only hosts, because
        # voicefixer checks CUDA availability when cuda is requested.
        vf.restore(input=processfile, output=processfile, cuda=torch.cuda.is_available(), mode=0)

        return jsonify({'message': 'Speech restored successfully', 'file': processfile})

    except Exception as e:
        logging.exception('VoiceFixer request failed')
        return jsonify({'API error': str(e)}), 500
|
|
|
|
|
@app.route('/parse_voice', methods=['POST'])
def parse_voice():
    """Extract an acoustic embedding from a reference recording.

    Expects a JSON body with:
        wav_path (str, required): recording passed to VoiceParser.
        out_dir (str, required): output directory handed to
            ``vp.extract_acoustic_embed``.

    Returns:
        200 with JSON ``{'message'}`` on success, otherwise 500 with JSON
        ``{'API error': <message>}``.
    """
    try:
        # Parse inside the try so bad requests get the JSON error payload.
        data = request.json
        wav_path = data['wav_path']
        out_dir = data['out_dir']

        logging.info(f'Parsing {wav_path} ...')

        vp.extract_acoustic_embed(wav_path, out_dir)

        return jsonify({'message': f'Successfully parsed {wav_path}'})

    except Exception as e:
        logging.exception('VoiceParser request failed')
        return jsonify({'API error': str(e)}), 500
|
|
|
|
|
if __name__ == '__main__':
    service_port = get_service_port()
    # Honour WAVJOURNEY_SERVICE_PORT, which was previously read and then
    # ignored. Fall back to the historical default of 7860 when it is unset.
    port = int(service_port) if service_port else 7860

    # threaded=False: handlers mutate shared global model state
    # (e.g. set_generation_params before generate), so serve one request at a time.
    app.run(debug=False, threaded=False, port=port)
|
|