dantts / src /main.py
imseldrith's picture
Update src/main.py
c410916 verified
raw
history blame
3.26 kB
from fastapi import FastAPI
import schemas
import uvicorn
from starlette.middleware.cors import CORSMiddleware
from functions import *
import base64
import os
import traceback
from bark import SAMPLE_RATE, generate_audio, preload_models
import soundfile as sf
import wave
import numpy as np
import nltk
# Port the HTTP server listens on; passed to uvicorn.run() in the __main__ block.
server_port = 7860
# Call Bark's preload_models() once at import time, before the app object is created.
# NOTE(review): presumably this loads the model weights up front so the first request
# does not pay that cost — confirm against the bark package.
preload_models()
# Create the FastAPI application with both interactive documentation routes
# (docs_url and redoc_url) disabled.
app = FastAPI(docs_url=None, redoc_url=None)
# Origins allowed to make cross-origin requests; consumed by the CORS middleware setup.
origins = ["*"]  # "*" means every origin is allowed.
def concatenate_wavs(wav_files, output_file, silence_duration=0.3):
    """Join several WAV files into one, adding a pause after each clip.

    The channel count, sample width and frame rate of the output are taken
    from the first input file; the remaining inputs are assumed to share that
    format (their raw frames are copied unchanged, no resampling is done).

    Args:
        wav_files: Iterable of paths (or readable binary file objects) of the
            WAV files to join, in playback order. Must not be empty.
        output_file: Path (or writable binary file object) of the WAV file
            to create.
        silence_duration: Length, in seconds, of the silence written after
            every clip. Values <= 0 disable the pause.

    Raises:
        ValueError: If ``wav_files`` is empty.
        wave.Error: If an input is not a WAV file the ``wave`` module can read.
    """
    wav_files = list(wav_files)
    if not wav_files:
        raise ValueError("wav_files must contain at least one WAV file")

    clips = []
    audio_format = None
    for source in wav_files:
        # The context manager closes each reader even if reading fails.
        with wave.open(source, 'rb') as reader:
            if audio_format is None:
                audio_format = (
                    reader.getnchannels(),
                    reader.getsampwidth(),
                    reader.getframerate(),
                )
            clips.append(reader.readframes(reader.getnframes()))
    nchannels, sampwidth, framerate = audio_format

    # Size the pause from the real audio format instead of assuming 16-bit
    # stereo, so mono input gets a pause too.
    pause_frames = max(int(silence_duration * framerate), 0)
    # 8-bit WAV samples are unsigned, so their zero level is 0x80; wider
    # samples are signed and silent at 0x00.
    quiet_byte = b'\x80' if sampwidth == 1 else b'\x00'
    pause = quiet_byte * (pause_frames * nchannels * sampwidth)

    with wave.open(output_file, 'wb') as writer:
        writer.setnchannels(nchannels)
        writer.setsampwidth(sampwidth)
        writer.setframerate(framerate)
        # The frame count is not set by hand: writeframes() keeps the header
        # in sync with the data actually written.
        for clip in clips:
            writer.writeframes(clip)
            if pause:
                writer.writeframes(pause)
# Register the CORS middleware so that pages served from other origins can call this API.
# NOTE(review): `origins` is ["*"] while allow_credentials is True. Browsers reject a
# wildcard Access-Control-Allow-Origin on credentialed requests — confirm whether
# credentials are really needed, or list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # Origins allowed to make cross-origin requests.
    allow_credentials=True,  # Allow credentials on cross-origin requests.
    allow_methods=["*"],  # HTTP methods allowed cross-origin (GET, POST, PUT, ...); "*" allows all.
    allow_headers=["*"])  # Request headers allowed cross-origin; "*" allows all.
@app.post("/tts_bark/")
async def tts_bark(item: schemas.generate_web):
    """Synthesize speech for ``item.text`` with Bark and return it as Opus/OGG.

    The text is split into sentences, each sentence is rendered to its own
    WAV file, the clips are joined with ``concatenate_wavs`` and the result is
    transcoded to OGG by the ``ffmpeg`` executable. All intermediate files
    live in a per-request temporary directory that is removed when the
    request finishes, whether it succeeds or fails.

    Args:
        item: Request body; only its ``text`` attribute is read here.

    Returns:
        On success, a dict with ``file_base64`` (base64 of the OGG bytes),
        ``audio_text`` (the input text) and ``file_name`` (a generated
        ``out-<timestamp>.ogg`` name). On any failure, a dict with
        ``code`` 9, ``msg``, ``err`` and ``traceback``; exceptions are not
        propagated to the caller.
    """
    # Imported here so the handler is self-contained and does not rely on the
    # module-level import block listing these two modules.
    import subprocess
    import tempfile

    # NOTE(review): `time` is not imported explicitly in this file; presumably it
    # arrives through `from functions import *` — confirm.
    time_start = time.time()
    text = item.text
    print(f"{text=}")
    try:
        sentences = nltk.sent_tokenize(text)
        file_name_ogg = f"out-{time.time()}.ogg"
        # A private directory per request keeps concurrent requests from
        # overwriting each other's tmp-N.wav files and guarantees cleanup.
        with tempfile.TemporaryDirectory(prefix="tts_bark-") as work_dir:
            wavs = []
            for idx, sentence in enumerate(sentences, start=1):
                # NOTE(review): generate_audio is a blocking call inside an
                # async handler, so the event loop is stalled while it runs.
                audio_array = generate_audio(
                    sentence,
                    history_prompt="en_speaker_8",
                    text_temp=0.6,
                    waveform_temp=0.6,
                )
                fname = os.path.join(work_dir, f"tmp-{idx}.wav")
                sf.write(fname, audio_array, SAMPLE_RATE)
                wavs.append(fname)
            file_path_wav = os.path.join(work_dir, "out.wav")
            file_path_ogg = os.path.join(work_dir, file_name_ogg)
            concatenate_wavs(wavs, file_path_wav)
            # Convert to OGG. An argument list avoids shell parsing, and
            # check=True turns a failed conversion into an exception instead
            # of a confusing "file not found" on the read below.
            subprocess.run(
                ["ffmpeg", "-i", file_path_wav,
                 "-c:a", "libopus", "-b:a", "64k", "-y", file_path_ogg],
                check=True,
            )
            with open(file_path_ogg, "rb") as f:
                audio_content = f.read()
        base64_audio = base64.b64encode(audio_content).decode("utf-8")
        res = {"file_base64": base64_audio,
               "audio_text": text,
               "file_name": file_name_ogg,
               }
        print_log(item, res, time_start)
        return res
    except Exception as err:
        # Top-level boundary of the endpoint: report the failure in the
        # response body rather than raising.
        # NOTE(review): the full traceback is sent to the client — confirm
        # that exposing it is intended.
        res = {"code": 9, "msg": "api error", "err": str(err), "traceback": traceback.format_exc()}
        print_log(item, res, time_start)
        return res
# Script entry point: runs only when this file is executed directly, not when imported.
if __name__ == '__main__':
    # print_env is not defined in this file; presumably it comes from
    # `from functions import *` and reports startup/environment details — verify there.
    print_env(server_port)
    # Serve the app defined in this module ("main:app") on all network
    # interfaces at server_port, with auto-reload turned off.
    uvicorn.run(app="main:app", host="0.0.0.0", port=server_port, reload=False)