Update app.py
app.py
CHANGED
@@ -1,86 +1,256 @@
-import
import torch
from flask import Flask, request, jsonify
-
-

app = Flask(__name__)

-# Load the fine-tuned model checkpoint if available; otherwise, load the pre-trained GPT-2 model
-if os.path.exists("fine_tuned_checkpoint"):
-    model = GPT2LMHeadModel.from_pretrained("fine_tuned_checkpoint")
-else:
-    model = GPT2LMHeadModel.from_pretrained("gpt2")

-

-
-def fine_tune_model(chat_history):
-    # Prepare training data for fine-tuning
-    input_texts = [item["message"] for item in chat_history]
-    with open("train.txt", "w") as f:
-        f.write("\n".join(input_texts))
-
-    # Load the dataset and create data collator
-    dataset = TextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
-    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
-
-    # Fine-tune the model
-    trainer = Trainer(model=model, data_collator=data_collator)
-    trainer.train("./training_directory")

-
-

-
-
-
-    user_input = request_data["user_input"]
-    chat_history = request_data.get("chat_history", [])

-
-

-
-

-    # Append bot message to the chat history
-    chat_history.append({"role": "bot", "message": response})

-

-
-
-
-
-
-

-

-
-
-

-    # Use the last `max_history` messages from the chat history
-    inputs = [item["message"] for item in chat_history[-max_history:]]
-    input_text = "\n".join(inputs)

-
-

-
-    with torch.no_grad():
-        output = model.generate(input_ids, max_length=100, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)

-
-

-

-@app.route("/")
-def index():
-    return jsonify({"status" : True})

-if __name__ ==
-
+import os
+import yaml
+import logging
+import nltk
import torch
+import torchaudio
+from torchaudio.transforms import SpeedPerturbation
+from APIs import WRITE_AUDIO, LOUDNESS_NORM
+# from utils import fade, get_service_port
from flask import Flask, request, jsonify
+import numpy as np
+
+def fade(audio_data, fade_duration=2, sr=32000):
+    audio_duration = audio_data.shape[0] / sr
+
+    # Automatically choose the fade duration
+    if audio_duration >= 8:
+        # keep fade_duration at 2 seconds
+        pass
+    else:
+        fade_duration = audio_duration / 5
+
+    fade_samples = int(sr * fade_duration)
+    fade_in = np.linspace(0, 1, fade_samples)
+    fade_out = np.linspace(1, 0, fade_samples)
+
+    audio_data_fade_in = audio_data[:fade_samples] * fade_in
+    audio_data_fade_out = audio_data[-fade_samples:] * fade_out
+
+    audio_data_faded = np.concatenate((audio_data_fade_in, audio_data[len(fade_in):-len(fade_out)], audio_data_fade_out))
+    return audio_data_faded
+
+def get_service_port():
+    service_port = os.environ.get('WAVJOURNEY_SERVICE_PORT')
+    return service_port
+
+with open('config.yaml', 'r') as file:
+    config = yaml.safe_load(file)
+
+# Configure the logging format and level
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+
+# Create a FileHandler for the log file
+os.makedirs('services_logs', exist_ok=True)
+log_filename = 'services_logs/Wav-API.log'
+file_handler = logging.FileHandler(log_filename, mode='w')
+file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+
+# Add the FileHandler to the root logger
+logging.getLogger('').addHandler(file_handler)
+
+
+"""
+Initialize the AudioCraft models here
+"""
+from audiocraft.models import AudioGen, MusicGen
+tta_model_size = config['AudioCraft']['tta_model_size']
+tta_model = AudioGen.get_pretrained(f'facebook/audiogen-{tta_model_size}')
+logging.info(f'AudioGen ({tta_model_size}) is loaded ...')
+
+ttm_model_size = config['AudioCraft']['ttm_model_size']
+ttm_model = MusicGen.get_pretrained(f'facebook/musicgen-{ttm_model_size}')
+logging.info(f'MusicGen ({ttm_model_size}) is loaded ...')
+
+
+"""
+Initialize the BarkModel here
+"""
+from transformers import BarkModel, AutoProcessor
+SPEED = float(config['Text-to-Speech']['speed'])
+speed_perturb = SpeedPerturbation(32000, [SPEED])
+tts_model = BarkModel.from_pretrained("suno/bark")
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+tts_model = tts_model.to(device)
+tts_model = tts_model.to_bettertransformer()  # Flash attention
+SAMPLE_RATE = tts_model.generation_config.sample_rate
+SEMANTIC_TEMPERATURE = 0.9
+COARSE_TEMPERATURE = 0.5
+FINE_TEMPERATURE = 0.5
+processor = AutoProcessor.from_pretrained("suno/bark")
+logging.info('Bark model is loaded ...')
+
+
+"""
+Initialize the VoiceFixer model here
+"""
+from voicefixer import VoiceFixer
+vf = VoiceFixer()
+logging.info('VoiceFixer is loaded ...')
+
+
+"""
+Initialize the VoiceParser model here
+"""
+from VoiceParser.model import VoiceParser
+vp_device = config['Voice-Parser']['device']
+vp = VoiceParser(device=vp_device)
+logging.info('VoiceParser is loaded ...')
+

app = Flask(__name__)


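+# The generation endpoints below each read a JSON payload, render audio to
+# disk with WRITE_AUDIO, and return the output filename in the JSON response.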
+@app.route('/generate_audio', methods=['POST'])
+def generate_audio():
+    # Receive the text from the POST request
+    data = request.json
+    text = data['text']
+    length = float(data.get('length', 5.0))
+    volume = float(data.get('volume', -35))
+    output_wav = data.get('output_wav', 'out.wav')

+    logging.info(f'TTA (AudioGen): Prompt: {text}, length: {length} seconds, volume: {volume} dB')

+    try:
+        tta_model.set_generation_params(duration=length)
+        wav = tta_model.generate([text])
+        # AudioGen outputs 16 kHz audio; resample to the 32 kHz working rate
+        wav = torchaudio.functional.resample(wav, orig_freq=16000, new_freq=32000)

+        wav = wav.squeeze().cpu().detach().numpy()
+        wav = fade(LOUDNESS_NORM(wav, volumn=volume))
+        WRITE_AUDIO(wav, name=output_wav)

+        # Return success message and the filename of the generated audio
+        return jsonify({'message': f'Text-to-Audio generated successfully | {text}', 'file': output_wav})

+    except Exception as e:
+        return jsonify({'API error': str(e)}), 500

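+# /generate_music mirrors /generate_audio, but MusicGen already outputs
+# audio at the 32 kHz working rate, so no resampling is needed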

+@app.route('/generate_music', methods=['POST'])
+def generate_music():
+    # Receive the text from the POST request
+    data = request.json
+    text = data['text']
+    length = float(data.get('length', 5.0))
+    volume = float(data.get('volume', -35))
+    output_wav = data.get('output_wav', 'out.wav')

+    logging.info(f'TTM (MusicGen): Prompt: {text}, length: {length} seconds, volume: {volume} dB')
+
+
+    try:
+        ttm_model.set_generation_params(duration=length)
+        wav = ttm_model.generate([text])
+        wav = wav[0][0].cpu().detach().numpy()
+        wav = fade(LOUDNESS_NORM(wav, volumn=volume))
+        WRITE_AUDIO(wav, name=output_wav)
+
+        # Return success message and the filename of the generated audio
+        return jsonify({'message': f'Text-to-Music generated successfully | {text}', 'file': output_wav})
+
+    except Exception as e:
+        # Return error message if something goes wrong
+        return jsonify({'API error': str(e)}), 500
+
+
+@app.route('/generate_speech', methods=['POST'])
+def generate_speech():
+    # Receive the text from the POST request
+    data = request.json
+    text = data['text']
+    speaker_id = data['speaker_id']
+    speaker_npz = data['speaker_npz']
+    volume = float(data.get('volume', -35))
+    output_wav = data.get('output_wav', 'out.wav')

+    logging.info(f'TTS (Bark): Speaker: {speaker_id}, Volume: {volume} dB, Prompt: {text}')
+
+    try:
+        # Generate audio sentence by sentence with the global Bark model
+        text = text.replace('\n', ' ').strip()
+        sentences = nltk.sent_tokenize(text)
+        silence = torch.zeros(int(0.1 * SAMPLE_RATE), device=device).unsqueeze(0)  # 0.1 second of silence
+
+        pieces = []
+        for sentence in sentences:
+            inputs = processor(sentence, voice_preset=speaker_npz).to(device)
+            # NOTE: the line below is required; without it, generation fails with:
+            # RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
+            inputs['history_prompt']['coarse_prompt'] = inputs['history_prompt']['coarse_prompt'].transpose(0, 1).contiguous().transpose(0, 1)
+
+            with torch.inference_mode():
+                # TODO: min_eos_p?
+                output = tts_model.generate(
+                    **inputs,
+                    do_sample=True,
+                    semantic_temperature=SEMANTIC_TEMPERATURE,
+                    coarse_temperature=COARSE_TEMPERATURE,
+                    fine_temperature=FINE_TEMPERATURE
+                )
+
+            pieces += [output, silence]
+
+        result_audio = torch.cat(pieces, dim=1)
+        wav_tensor = result_audio.to(dtype=torch.float32).cpu()
+        wav = torchaudio.functional.resample(wav_tensor, orig_freq=SAMPLE_RATE, new_freq=32000)
+        wav = speed_perturb(wav.float())[0].squeeze(0)
+        wav = wav.numpy()
+        wav = LOUDNESS_NORM(wav, volumn=volume)
+        WRITE_AUDIO(wav, name=output_wav)
+
+        # Return success message and the filename of the generated audio
+        return jsonify({'message': f'Text-to-Speech generated successfully | {speaker_id}: {text}', 'file': output_wav})
+
+    except Exception as e:
+        # Return error message if something goes wrong
+        return jsonify({'API error': str(e)}), 500
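+# VoiceFixer restores the file in-place: the input path is overwritten
+# with the enhanced audio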
+
+
+@app.route('/fix_audio', methods=['POST'])
+def fix_audio():
+    # Receive the audio file path from the POST request
+    data = request.json
+    processfile = data['processfile']
+
+    logging.info(f'Fixing {processfile} ...')
+
+    try:
+        vf.restore(input=processfile, output=processfile, cuda=True, mode=0)
+
+        # Return success message and the filename of the restored audio
+        return jsonify({'message': 'Speech restored successfully', 'file': processfile})

+    except Exception as e:
+        # Return error message if something goes wrong
+        return jsonify({'API error': str(e)}), 500


+@app.route('/parse_voice', methods=['POST'])
+def parse_voice():
+    # Receive the source wav path from the POST request
+    data = request.json
+    wav_path = data['wav_path']
+    out_dir = data['out_dir']

+    logging.info(f'Parsing {wav_path} ...')

+    try:
+        vp.extract_acoustic_embed(wav_path, out_dir)
+
+        # Return success message
+        return jsonify({'message': f'Successfully parsed {wav_path}'})

+    except Exception as e:
+        # Return error message if something goes wrong
+        return jsonify({'API error': str(e)}), 500


+if __name__ == '__main__':
+    service_port = get_service_port()
+    # We disable multithreading to force services to process one request at a time and avoid CUDA OOM
+    app.run(debug=False, threaded=False, port=int(service_port) if service_port else 7860)
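For reference, a minimal client sketch for the service above (not part of the commit): it assumes the app is running locally, that WAVJOURNEY_SERVICE_PORT, if set, matches the port the app was started on, and that the speaker preset path is a hypothetical example of an embedding produced by /parse_voice.

# client_example.py -- usage sketch for the endpoints defined in app.py
import os
import requests

port = os.environ.get('WAVJOURNEY_SERVICE_PORT', '7860')
base_url = f'http://127.0.0.1:{port}'

# Text-to-Audio: a 3-second sound effect, loudness-normalized to -35 dB
resp = requests.post(f'{base_url}/generate_audio', json={
    'text': 'a dog barking in the distance',
    'length': 3.0,
    'volume': -35,
    'output_wav': 'dog_bark.wav',
})
print(resp.json())

# Text-to-Speech: 'speaker_npz' should point to a Bark voice preset, e.g.
# one produced by the /parse_voice endpoint (the path below is hypothetical)
resp = requests.post(f'{base_url}/generate_speech', json={
    'text': 'Welcome to the audio generation service.',
    'speaker_id': 'narrator',
    'speaker_npz': 'voice_presets/narrator.npz',
    'volume': -20,
    'output_wav': 'narration.wav',
})
print(resp.json())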