File size: 7,661 Bytes
da8e0c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
from typing import *
from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect, Query
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
import asyncio
import logging
from pydantic import BaseModel, Field
import uvicorn
from voiceapi.tts import TTSResult, start_tts_stream, TTSStream
from voiceapi.asr import start_asr_stream, ASRStream, ASRResult
import logging
import argparse
import os
# ASGI application; every endpoint below is registered on it.
app = FastAPI()
# Name the logger after the module, per logging convention, rather than after
# its file path: a path such as "./server.py" contains dots, which the logging
# module interprets as separators in a logger hierarchy.
logger = logging.getLogger(__name__)
@app.websocket("/asr")
async def websocket_asr(websocket: WebSocket,
                        samplerate: int = Query(16000, title="Sample Rate",
                                                description="The sample rate of the audio."),):
    """Run speech recognition over a WebSocket connection.

    Binary frames received from the client are written to an ASR stream
    created by ``start_asr_stream(samplerate, args)``. Every result read back
    from that stream is sent to the client as JSON (``result.to_dict()``).
    An empty binary frame ends the receiving side, and a falsy result ends
    the sending side.

    Args:
        websocket: The client connection.
        samplerate: Sample rate of the incoming audio, forwarded to
            ``start_asr_stream``.

    NOTE: relies on the module-level ``args``, which is only assigned when
    this file is run as a script.
    """
    await websocket.accept()
    asr_stream: ASRStream = await start_asr_stream(samplerate, args)
    if not asr_stream:
        logger.error("failed to start ASR stream")
        await websocket.close()
        return

    async def task_recv_pcm():
        # Forward client audio into the recognizer until an empty frame.
        while True:
            pcm_bytes = await websocket.receive_bytes()
            if not pcm_bytes:
                return
            await asr_stream.write(pcm_bytes)

    async def task_send_result():
        # Relay recognizer output to the client until a falsy result.
        while True:
            result: ASRResult = await asr_stream.read()
            if not result:
                return
            await websocket.send_json(result.to_dict())

    tasks = [asyncio.create_task(task_recv_pcm()),
             asyncio.create_task(task_send_result())]
    try:
        await asyncio.gather(*tasks)
    except WebSocketDisconnect:
        logger.info("asr: disconnected")
    finally:
        # gather() does not cancel the other task when one of them raises.
        # Without this, the survivor would keep running after the connection
        # is gone, using a stream that is closed just below.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        await asr_stream.close()
@app.websocket("/tts")
async def websocket_tts(websocket: WebSocket,
                        samplerate: int = Query(16000,
                                                title="Sample Rate",
                                                description="The sample rate of the generated audio."),
                        interrupt: bool = Query(True,
                                                title="Interrupt",
                                                description="Interrupt the current TTS stream when a new text is received."),
                        sid: int = Query(0,
                                         title="Speaker ID",
                                         description="The ID of the speaker to use for TTS."),
                        chunk_size: int = Query(1024,
                                                title="Chunk Size",
                                                description="The size of the chunk to send to the client.",
                                                gt=0),
                        speed: float = Query(1.0,
                                             title="Speed",
                                             description="The speed of the generated audio."),
                        split: bool = Query(True,
                                            title="Split",
                                            description="Split the text into sentences.")):
    """Synthesize speech for text frames received on a WebSocket.

    Each text frame is written (with ``split``) to a TTS stream created from
    ``sid``, ``samplerate`` and ``speed``. Results read from the stream go
    back to the client in one of two ways:

    * a result whose ``finished`` flag is set is sent as JSON;
    * any other result's ``pcm_bytes`` are sent as binary frames of at most
      ``chunk_size`` bytes.

    When ``interrupt`` is true, every new text closes the current stream and
    starts a new one. Otherwise all texts go to the first stream.

    An empty text frame ends the receiving side. If a stream cannot be
    allocated, the socket is closed.

    NOTE: relies on the module-level ``args``, which is only assigned when
    this file is run as a script.
    """
    await websocket.accept()
    tts_stream: Optional[TTSStream] = None
    # Set when the receiver exits, so the sender stops waiting for a stream
    # that will never be created (empty first text, allocation failure).
    recv_done = asyncio.Event()

    async def task_recv_text():
        nonlocal tts_stream
        try:
            while True:
                text = await websocket.receive_text()
                if not text:
                    return
                if interrupt or not tts_stream:
                    if tts_stream:
                        await tts_stream.close()
                        logger.info("tts: stream interrupt")
                    tts_stream = await start_tts_stream(sid, samplerate, speed, args)
                    if not tts_stream:
                        logger.error("tts: failed to allocate tts stream")
                        await websocket.close()
                        return
                logger.info(f"tts: received: {text} (split={split})")
                await tts_stream.write(text, split)
        finally:
            recv_done.set()

    async def task_send_pcm():
        # Wait for the receiver to create the first stream, but give up once
        # the receiver has finished without creating one.
        while not tts_stream:
            if recv_done.is_set():
                return
            await asyncio.sleep(0.1)
        while True:
            # Re-read the shared variable every iteration: an interrupt may
            # have replaced the stream, or left it None if the replacement
            # failed to start.
            stream = tts_stream
            if not stream:
                return
            # NOTE(review): on interrupt the old stream is closed while this
            # read may be pending on it; if read() then returns a falsy value
            # the sender stops and audio for later texts is not sent --
            # confirm TTSStream.close()/read() semantics.
            result: TTSResult = await stream.read()
            if not result:
                return
            if result.finished:
                await websocket.send_json(result.to_dict())
            else:
                pcm = result.pcm_bytes
                for offset in range(0, len(pcm), chunk_size):
                    await websocket.send_bytes(pcm[offset:offset + chunk_size])

    tasks = [asyncio.create_task(task_recv_text()),
             asyncio.create_task(task_send_pcm())]
    try:
        await asyncio.gather(*tasks)
    except WebSocketDisconnect:
        logger.info("tts: disconnected")
    finally:
        # gather() does not cancel the other task when one of them raises.
        # Cancel and reap both so nothing outlives the connection.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        if tts_stream:
            await tts_stream.close()
class TTSRequest(BaseModel):
    # JSON body accepted by the POST /tts endpoint.

    # Text to synthesize; required (no default value).
    text: str = Field(default=...,
                      description="The text to be converted to speech.",
                      title="Text",
                      examples=["Hello, world!"])
    # Speaker ID, forwarded to start_tts_stream.
    sid: int = Field(default=0,
                     description="The ID of the speaker to use for TTS.",
                     title="Speaker ID")
    # Output sample rate, forwarded to start_tts_stream.
    samplerate: int = Field(default=16000,
                            description="The sample rate of the generated audio.",
                            title="Sample Rate")
    # Speed factor, forwarded to start_tts_stream.
    speed: float = Field(default=1.0,
                         description="The speed of the generated audio.",
                         title="Speed")
@app.post("/tts",
          description="Generate speech audio from text.",
          response_class=StreamingResponse, responses={200: {"content": {"audio/wav": {}}}})
async def tts_generate(req: TTSRequest):
    """Synthesize ``req.text`` and return it as an ``audio/wav`` response.

    The body is whatever ``tts_stream.generate`` returns, streamed with
    ``StreamingResponse``. The TTS stream is closed after the response has
    been sent, or immediately if generation fails. This matches the
    WebSocket endpoints, which always close their streams.

    Raises:
        HTTPException: 400 if ``req.text`` is empty; 500 if no TTS stream
            could be started.
    """
    # Local import: starlette is a FastAPI dependency, but this module does
    # not import it at the top.
    from starlette.background import BackgroundTask

    if not req.text:
        raise HTTPException(status_code=400, detail="text is required")
    tts_stream = await start_tts_stream(req.sid, req.samplerate, req.speed, args)
    if not tts_stream:
        raise HTTPException(
            status_code=500, detail="failed to start TTS stream")
    try:
        r = await tts_stream.generate(req.text)
    except BaseException:
        await tts_stream.close()
        raise

    async def close_stream():
        await tts_stream.close()

    # Close only after the body has been sent, in case `r` still depends on
    # the stream while the response is streaming.
    return StreamingResponse(r, media_type="audio/wav",
                             background=BackgroundTask(close_stream))
if __name__ == "__main__":
    # Configure logging before anything below logs (e.g. the CUDA fallback
    # warning). Otherwise that record is emitted before the handler and
    # format are set up.
    logging.basicConfig(format='%(levelname)s: %(asctime)s %(name)s:%(lineno)s %(message)s',
                        level=logging.INFO)

    # Default models directory: the first of ./models, ../models or
    # ../../models (relative to the working directory) that exists,
    # else './models'.
    models_root = './models'
    for d in ['.', '..', '../..']:
        if os.path.isdir(f'{d}/models'):
            models_root = f'{d}/models'
            break

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument("--addr", type=str,
                        default="0.0.0.0", help="serve address")
    parser.add_argument("--asr-provider", type=str,
                        default="cpu", help="asr provider, cpu or cuda")
    parser.add_argument("--tts-provider", type=str,
                        default="cpu", help="tts provider, cpu or cuda")
    parser.add_argument("--threads", type=int, default=2,
                        help="number of threads")
    parser.add_argument("--models-root", type=str, default=models_root,
                        help="model root directory")
    parser.add_argument("--asr-model", type=str, default='sensevoice',
                        help="ASR model name: zipformer-bilingual, sensevoice, paraformer-trilingual, paraformer-en")
    parser.add_argument("--asr-lang", type=str, default='zh',
                        help="ASR language, zh, en, ja, ko, yue")
    parser.add_argument("--tts-model", type=str, default='vits-zh-hf-theresa',
                        help="TTS model name: vits-zh-hf-theresa, vits-melo-tts-zh_en")
    # `args` becomes a module-level global here; the endpoint handlers read
    # it, so they only work when the server is launched via this entry point.
    args = parser.parse_args()

    if args.tts_model == 'vits-melo-tts-zh_en' and args.tts_provider == 'cuda':
        logger.warning(
            "vits-melo-tts-zh_en does not support CUDA, falling back to CPU")
        args.tts_provider = 'cpu'

    # Serve ./assets (relative to the working directory) at the site root.
    # Mounted after the API routes were registered above.
    app.mount("/", app=StaticFiles(directory="./assets", html=True), name="assets")
    uvicorn.run(app, host=args.addr, port=args.port)
|