File size: 3,257 Bytes
a07ed46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c410916
a07ed46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from fastapi import FastAPI
import schemas
import uvicorn
from starlette.middleware.cors import CORSMiddleware
from functions import *
import base64
import os
import traceback

from bark import SAMPLE_RATE, generate_audio, preload_models
import soundfile as sf
import wave
import numpy as np
import nltk

# TCP port the HTTP server listens on (passed to uvicorn.run in the __main__ block).
server_port = 7860

# Load the Bark models at import time, before the app object is created.
# NOTE(review): presumably done here so the first request does not pay the
# model-loading cost -- confirm against the bark documentation.
preload_models()

# docs_url=None / redoc_url=None switch off FastAPI's interactive documentation pages.
app = FastAPI(docs_url=None, redoc_url=None)

# Origins handed to the CORS middleware as allow_origins.
origins = ["*"]  # "*" allows every origin.


def concatenate_wavs(wav_files, output_file, silence_duration=0.3):
    """Join several PCM WAV files into one, with a pause after each segment.

    The output takes its channel count, sample width and frame rate from the
    first input file. After every segment (including the last one)
    ``silence_duration`` seconds of silence are appended.

    Args:
        wav_files: Paths (or readable file objects) of the WAV files to join,
            in playback order. All of them must share the same channel count,
            sample width and frame rate.
        output_file: Path (or writable file object) of the WAV file to create.
        silence_duration: Length of the pause after each segment, in seconds.
            Zero or a negative value disables the pause.

    Raises:
        ValueError: If ``wav_files`` is empty or the inputs do not all share
            the same audio format.
    """
    wav_files = list(wav_files)
    if not wav_files:
        raise ValueError("wav_files must contain at least one file")

    segments = []
    audio_format = None
    for source in wav_files:
        # The context manager closes each input even if reading fails.
        with wave.open(source, 'rb') as wav:
            current = (wav.getnchannels(), wav.getsampwidth(), wav.getframerate())
            if audio_format is None:
                audio_format = current
            elif current != audio_format:
                # Raw frames of different formats cannot be appended meaningfully.
                raise ValueError(
                    f"WAV format mismatch: expected {audio_format}, got {current} for {source!r}"
                )
            segments.append(wav.readframes(wav.getnframes()))

    nchannels, sampwidth, framerate = audio_format

    # Size the pause from the real format instead of assuming 16-bit stereo;
    # one frame is nchannels * sampwidth bytes.
    silence_frames = max(int(silence_duration * framerate), 0)
    # 8-bit WAV samples are unsigned, so their silence level is 128, not 0.
    silence_byte = b"\x80" if sampwidth == 1 else b"\x00"
    silence = silence_byte * (silence_frames * nchannels * sampwidth)

    with wave.open(output_file, 'wb') as output:
        output.setnchannels(nchannels)
        output.setsampwidth(sampwidth)
        output.setframerate(framerate)
        # No frame count is set up front: the wave module patches the header
        # with the number of frames actually written when the file is closed.
        for segment in segments:
            output.writeframes(segment)
            if silence:
                output.writeframes(silence)


# Register CORS handling so browser clients served from other origins can call this API.
# NOTE(review): origins is ["*"] and allow_credentials is True, which is a very
# permissive combination -- confirm that credentialed cross-origin requests from
# any site are really intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # Origins permitted to make cross-origin requests
    allow_credentials=True,
    allow_methods=["*"],  # HTTP methods permitted for cross-origin requests (GET, POST, PUT, ...)
    allow_headers=["*"])  # Request headers permitted on cross-origin requests


@app.post("/tts_bark/")
async def tts_bark(item: schemas.generate_web):
    """Synthesize ``item.text`` with Bark and return the audio as base64 OGG/Opus.

    The text is split into sentences, each sentence is rendered to its own WAV
    file, the WAV files are joined with ``concatenate_wavs`` and the result is
    transcoded to OGG/Opus by the ``ffmpeg`` executable.

    Args:
        item: Request body; only its ``text`` attribute is read here.

    Returns:
        On success a dict with ``file_base64`` (base64 of the OGG file),
        ``audio_text`` (the input text) and ``file_name`` (generated OGG name).
        On any failure a dict with ``code`` 9, ``msg``, ``err`` and
        ``traceback``; exceptions are not propagated to the framework.
    """
    # Function-scope imports keep this handler self-contained: the module only
    # gets `time` indirectly (if at all) through `from functions import *`.
    import subprocess
    import tempfile
    import time

    time_start = time.time()
    text = item.text
    print(f"{text=}")
    try:
        sentences = nltk.sent_tokenize(text)
        if not sentences:
            raise ValueError("text contains no sentences to synthesize")

        file_name_pre = f"out-{time.time()}"
        file_name_ogg = file_name_pre + ".ogg"

        # A private directory per request keeps concurrent requests from
        # overwriting each other's tmp-N.wav files and guarantees that every
        # intermediate file is removed, including when a step below fails.
        with tempfile.TemporaryDirectory(prefix="tts_bark-") as work_dir:
            wavs = []
            for idx, s in enumerate(sentences, start=1):
                # NOTE(review): this call runs synchronously inside an async
                # handler, so it presumably blocks the event loop while it works.
                audio_array = generate_audio(s, history_prompt="en_speaker_8", text_temp=0.6, waveform_temp=0.6)
                fname = os.path.join(work_dir, f"tmp-{idx}.wav")
                sf.write(fname, audio_array, SAMPLE_RATE)
                wavs.append(fname)

            path_wav = os.path.join(work_dir, file_name_pre + ".wav")
            path_ogg = os.path.join(work_dir, file_name_ogg)
            concatenate_wavs(wavs, path_wav)

            # convert to OGG; an argument list avoids the shell, and check=True
            # turns a failed conversion into an error instead of a later
            # "file not found".
            subprocess.run(
                ["ffmpeg", "-i", path_wav, "-c:a", "libopus", "-b:a", "64k", "-y", path_ogg],
                check=True,
            )

            with open(path_ogg, "rb") as f:
                audio_content = f.read()

        base64_audio = base64.b64encode(audio_content).decode("utf-8")
        res = {"file_base64": base64_audio,
               "audio_text": text,
               "file_name": file_name_ogg,
               }
        print_log(item, res, time_start)

        return res
    except Exception as err:
        # Top-level boundary: report the failure in the response body.
        # NOTE(review): the traceback is sent to the client -- confirm this is
        # acceptable outside of debugging.
        res = {"code": 9, "msg": "api error", "err": str(err), "traceback": traceback.format_exc()}
        print_log(item, res, time_start)
        return res

# Script entry point: runs only when this file is executed directly, not when imported.
if __name__ == '__main__':

    # print_env is not defined in this file; presumably it comes from
    # `from functions import *` -- verify there.
    print_env(server_port)
    # "main:app" is an import string, so this assumes the file is importable as
    # module `main` -- TODO confirm the file name. host="0.0.0.0" listens on all
    # network interfaces; reload=False disables uvicorn's auto-reload.
    uvicorn.run(app="main:app", host="0.0.0.0", port=server_port, reload=False)