ZeroTwo3 committed
Commit d00003a · 1 Parent(s): 389b910

Update app.py

Files changed (1)
  1. app.py +233 -63
app.py CHANGED
@@ -1,86 +1,256 @@
-import json
 import torch
 from flask import Flask, request, jsonify
-from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling, Trainer
-import os

 app = Flask(__name__)

-# Load the fine-tuned model checkpoint if available; otherwise, load the pre-trained GPT-2 model
-if os.path.exists("fine_tuned_checkpoint"):
-    model = GPT2LMHeadModel.from_pretrained("fine_tuned_checkpoint")
-else:
-    model = GPT2LMHeadModel.from_pretrained("gpt2")
-
-tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-
-# Function to fine-tune the model
-def fine_tune_model(chat_history):
-    # Prepare training data for fine-tuning
-    input_texts = [item["message"] for item in chat_history]
-    with open("train.txt", "w") as f:
-        f.write("\n".join(input_texts))
-
-    # Load the dataset and create data collator
-    dataset = TextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
-    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
-
-    # Fine-tune the model
-    trainer = Trainer(model=model, data_collator=data_collator)
-    trainer.train("./training_directory")
-
-    # Save the fine-tuned model
-    model.save_pretrained("fine_tuned_model")
-
-@app.route("/chat", methods=["POST"])
-def chat_with_model():
-    request_data = request.get_json()
-    user_input = request_data["user_input"]
-    chat_history = request_data.get("chat_history", [])
-
-    # Append user message to the chat history
-    chat_history.append({"role": "user", "message": user_input})
-
-    # Generate response
-    response = generate_response(user_input, chat_history)
-
-    # Append bot message to the chat history
-    chat_history.append({"role": "bot", "message": response})
-
-    return jsonify({"bot_response": response, "chat_history": chat_history})
-
-@app.route("/train", methods=["POST"])
-def train_model():
-    chat_history = request.json["data"]
-
-    # Fine-tune the model with the provided data
-    fine_tune_model(chat_history)
-
-    return "Model trained and updated successfully."
-
-def generate_response(user_input, chat_history):
-    # Set the maximum number of previous messages to consider
-    max_history = 3
-
-    # Use the last `max_history` messages from the chat history
-    inputs = [item["message"] for item in chat_history[-max_history:]]
-    input_text = "\n".join(inputs)
-
-    # Tokenize the input text
-    input_ids = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=True)
-
-    # Generate response
-    with torch.no_grad():
-        output = model.generate(input_ids, max_length=100, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
-
-    # Decode response and extract bot message
-    bot_response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
-
-    return bot_response
-
-@app.route("/")
-def index():
-    return jsonify({"status": True})
-
-if __name__ == "__main__":
-    app.run(host="0.0.0.0", port=7860)
+import os
+import yaml
+import logging
+import nltk
 import torch
+import torchaudio
+from torchaudio.transforms import SpeedPerturbation
+from APIs import WRITE_AUDIO, LOUDNESS_NORM
+# from utils import fade, get_service_port
 from flask import Flask, request, jsonify
+import numpy as np
+
+def fade(audio_data, fade_duration=2, sr=32000):
+    audio_duration = audio_data.shape[0] / sr
+
+    # Automatically choose the fade duration
+    if audio_duration >= 8:
+        # keep fade_duration at 2 seconds
+        pass
+    else:
+        fade_duration = audio_duration / 5
+
+    fade_samples = int(sr * fade_duration)
+    fade_in = np.linspace(0, 1, fade_samples)
+    fade_out = np.linspace(1, 0, fade_samples)
+
+    audio_data_fade_in = audio_data[:fade_samples] * fade_in
+    audio_data_fade_out = audio_data[-fade_samples:] * fade_out
+
+    audio_data_faded = np.concatenate((audio_data_fade_in, audio_data[len(fade_in):-len(fade_out)], audio_data_fade_out))
+    return audio_data_faded
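A quick sanity check on the fade rule: a 5-second clip falls below the 8-second threshold, so fade_duration becomes 5 / 5 = 1 second, which is 32,000 samples per ramp at the 32 kHz rate this service works in.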
+
+def get_service_port():
+    service_port = os.environ.get('WAVJOURNEY_SERVICE_PORT')
+    return service_port
+
+with open('config.yaml', 'r') as file:
+    config = yaml.safe_load(file)
+
+# Configure the logging format and level
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+
+# Create a FileHandler for the log file
+os.makedirs('services_logs', exist_ok=True)
+log_filename = 'services_logs/Wav-API.log'
+file_handler = logging.FileHandler(log_filename, mode='w')
+file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+
+# Add the FileHandler to the root logger
+logging.getLogger('').addHandler(file_handler)
+
+
+"""
+Initialize the AudioCraft models here
+"""
+from audiocraft.models import AudioGen, MusicGen
+tta_model_size = config['AudioCraft']['tta_model_size']
+tta_model = AudioGen.get_pretrained(f'facebook/audiogen-{tta_model_size}')
+logging.info(f'AudioGen ({tta_model_size}) is loaded ...')
+
+ttm_model_size = config['AudioCraft']['ttm_model_size']
+ttm_model = MusicGen.get_pretrained(f'facebook/musicgen-{ttm_model_size}')
+logging.info(f'MusicGen ({ttm_model_size}) is loaded ...')
+
+
+"""
+Initialize the BarkModel here
+"""
+from transformers import BarkModel, AutoProcessor
+SPEED = float(config['Text-to-Speech']['speed'])
+speed_perturb = SpeedPerturbation(32000, [SPEED])
+tts_model = BarkModel.from_pretrained("suno/bark")
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+tts_model = tts_model.to(device)
+tts_model = tts_model.to_bettertransformer()  # Flash attention
+SAMPLE_RATE = tts_model.generation_config.sample_rate
+SEMANTIC_TEMPERATURE = 0.9
+COARSE_TEMPERATURE = 0.5
+FINE_TEMPERATURE = 0.5
+processor = AutoProcessor.from_pretrained("suno/bark")
+logging.info('Bark model is loaded ...')
+
+
+"""
+Initialize the VoiceFixer model here
+"""
+from voicefixer import VoiceFixer
+vf = VoiceFixer()
+logging.info('VoiceFixer is loaded ...')
+
+
+"""
+Initialize the VoiceParser model here
+"""
+from VoiceParser.model import VoiceParser
+vp_device = config['Voice-Parser']['device']
+vp = VoiceParser(device=vp_device)
+logging.info('VoiceParser is loaded ...')
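Taken together, the config keys read above imply a config.yaml of roughly the following shape. The values shown here are illustrative assumptions, not part of this commit (note that the key names are the only part grounded in the code):

    AudioCraft:
      tta_model_size: medium
      ttm_model_size: medium
    Text-to-Speech:
      speed: 1.05
    Voice-Parser:
      device: cpu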
+

 app = Flask(__name__)

+@app.route('/generate_audio', methods=['POST'])
+def generate_audio():
+    # Receive the text from the POST request
+    data = request.json
+    text = data['text']
+    length = float(data.get('length', 5.0))
+    volume = float(data.get('volume', -35))
+    output_wav = data.get('output_wav', 'out.wav')
+
+    logging.info(f'TTA (AudioGen): Prompt: {text}, length: {length} seconds, volume: {volume} dB')
+
+    try:
+        tta_model.set_generation_params(duration=length)
+        wav = tta_model.generate([text])
+        wav = torchaudio.functional.resample(wav, orig_freq=16000, new_freq=32000)
+
+        wav = wav.squeeze().cpu().detach().numpy()
+        wav = fade(LOUDNESS_NORM(wav, volumn=volume))
+        WRITE_AUDIO(wav, name=output_wav)
+
+        # Return success message and the filename of the generated audio
+        return jsonify({'message': f'Text-to-Audio generated successfully | {text}', 'file': output_wav})
+
+    except Exception as e:
+        return jsonify({'API error': str(e)}), 500
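A request to this route could look like the following sketch; the prompt and filename are made-up values, and localhost:7860 assumes the port hard-coded at the bottom of the file:

    import requests

    resp = requests.post('http://localhost:7860/generate_audio',
                         json={'text': 'dog barking', 'length': 3.0,
                               'volume': -35, 'output_wav': 'dog.wav'})
    print(resp.json())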
+
+
+@app.route('/generate_music', methods=['POST'])
+def generate_music():
+    # Receive the text from the POST request
+    data = request.json
+    text = data['text']
+    length = float(data.get('length', 5.0))
+    volume = float(data.get('volume', -35))
+    output_wav = data.get('output_wav', 'out.wav')
+
+    logging.info(f'TTM (MusicGen): Prompt: {text}, length: {length} seconds, volume: {volume} dB')
+
+    try:
+        ttm_model.set_generation_params(duration=length)
+        wav = ttm_model.generate([text])
+        wav = wav[0][0].cpu().detach().numpy()
+        wav = fade(LOUDNESS_NORM(wav, volumn=volume))
+        WRITE_AUDIO(wav, name=output_wav)
+
+        # Return success message and the filename of the generated audio
+        return jsonify({'message': f'Text-to-Music generated successfully | {text}', 'file': output_wav})
+
+    except Exception as e:
+        # Return error message if something goes wrong
+        return jsonify({'API error': str(e)}), 500
+
+
+@app.route('/generate_speech', methods=['POST'])
+def generate_speech():
+    # Receive the text from the POST request
+    data = request.json
+    text = data['text']
+    speaker_id = data['speaker_id']
+    speaker_npz = data['speaker_npz']
+    volume = float(data.get('volume', -35))
+    output_wav = data.get('output_wav', 'out.wav')
+
+    logging.info(f'TTS (Bark): Speaker: {speaker_id}, Volume: {volume} dB, Prompt: {text}')
+
+    try:
+        # Generate audio using the global pipe object
+        text = text.replace('\n', ' ').strip()
+        sentences = nltk.sent_tokenize(text)
+        silence = torch.zeros(int(0.1 * SAMPLE_RATE), device=device).unsqueeze(0)  # 0.1 second of silence
+
+        pieces = []
+        for sentence in sentences:
+            inputs = processor(sentence, voice_preset=speaker_npz).to(device)
+            # NOTE: the line below is required; otherwise generation fails with:
+            # RuntimeError: view size is not compatible with input tensor's size and stride
+            # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
+            inputs['history_prompt']['coarse_prompt'] = inputs['history_prompt']['coarse_prompt'].transpose(0, 1).contiguous().transpose(0, 1)
+
+            with torch.inference_mode():
+                # TODO: min_eos_p?
+                output = tts_model.generate(
+                    **inputs,
+                    do_sample=True,
+                    semantic_temperature=SEMANTIC_TEMPERATURE,
+                    coarse_temperature=COARSE_TEMPERATURE,
+                    fine_temperature=FINE_TEMPERATURE
+                )
+
+            pieces += [output, silence]
+
+        result_audio = torch.cat(pieces, dim=1)
+        wav_tensor = result_audio.to(dtype=torch.float32).cpu()
+        wav = torchaudio.functional.resample(wav_tensor, orig_freq=SAMPLE_RATE, new_freq=32000)
+        wav = speed_perturb(wav.float())[0].squeeze(0)
+        wav = wav.numpy()
+        wav = LOUDNESS_NORM(wav, volumn=volume)
+        WRITE_AUDIO(wav, name=output_wav)
+
+        # Return success message and the filename of the generated audio
+        return jsonify({'message': f'Text-to-Speech generated successfully | {speaker_id}: {text}', 'file': output_wav})
+
+    except Exception as e:
+        # Return error message if something goes wrong
+        return jsonify({'API error': str(e)}), 500
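A note on the chunking in this route: Bark can only synthesize short clips per call (on the order of ten seconds), so splitting the prompt into sentences with nltk.sent_tokenize and concatenating the per-sentence outputs, separated by 0.1 s of silence, is what lets the route handle longer passages. sent_tokenize also requires the NLTK punkt data to be available at runtime.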
+
+
+@app.route('/fix_audio', methods=['POST'])
+def fix_audio():
+    # Receive the filename from the POST request
+    data = request.json
+    processfile = data['processfile']
+
+    logging.info(f'Fixing {processfile} ...')
+
+    try:
+        vf.restore(input=processfile, output=processfile, cuda=True, mode=0)
+
+        # Return success message and the filename of the restored audio
+        return jsonify({'message': 'Speech restored successfully', 'file': processfile})
+
+    except Exception as e:
+        # Return error message if something goes wrong
+        return jsonify({'API error': str(e)}), 500
+
+
+@app.route('/parse_voice', methods=['POST'])
+def parse_voice():
+    # Receive the file path and output directory from the POST request
+    data = request.json
+    wav_path = data['wav_path']
+    out_dir = data['out_dir']
+
+    logging.info(f'Parsing {wav_path} ...')
+
+    try:
+        vp.extract_acoustic_embed(wav_path, out_dir)
+
+        # Return success message for the parsed voice
+        return jsonify({'message': f'Successfully parsed {wav_path}'})
+
+    except Exception as e:
+        # Return error message if something goes wrong
+        return jsonify({'API error': str(e)}), 500
+
+
+if __name__ == '__main__':
+    service_port = get_service_port()
+    # We disable multithreading to force services to process one request at a time and avoid CUDA OOM
+    app.run(debug=False, threaded=False, port=7860)
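Together, the last three routes support a parse-then-speak flow. A minimal client sketch follows; the payload keys come from the handlers above, but the file paths are placeholders (the exact .npz filename that VoiceParser writes is not shown in this diff), and the port is the hard-coded 7860 from app.run:

    import requests

    BASE = 'http://localhost:7860'

    # 1. Extract a voice embedding from a reference recording
    requests.post(f'{BASE}/parse_voice',
                  json={'wav_path': 'ref_voice.wav', 'out_dir': 'voices'})

    # 2. Synthesize speech with that voice preset (placeholder .npz path)
    requests.post(f'{BASE}/generate_speech', json={
        'text': 'Hello there!',
        'speaker_id': 'ref_voice',
        'speaker_npz': 'voices/ref_voice.npz',
        'volume': -20,
        'output_wav': 'hello.wav',
    })

    # 3. Clean up the result with VoiceFixer
    requests.post(f'{BASE}/fix_audio', json={'processfile': 'hello.wav'})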