# Source: fishspeech2 / tools/post_api.py
# Author: pineconeT94 — commit 8b14bed ("first commit")
import argparse
import base64
import wave
import ormsgpack
import pyaudio
import requests
from pydub import AudioSegment
from pydub.playback import play
from tools.file import audio_to_bytes, read_ref_text
from tools.schema import ServeReferenceAudio, ServeTTSRequest
def _str2bool(value: str) -> bool:
    """Convert a command-line string such as ``true``/``false``/``1``/``0`` to bool.

    ``argparse`` with ``type=bool`` evaluates ``bool("False")``, which is
    ``True`` for every non-empty string, so boolean options could never be
    turned off. An empty string still maps to ``False``, matching ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of the TTS client.

    Boolean options (``--play``, ``--normalize``, ``--streaming``) take an
    explicit value, e.g. ``--play false``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Send a WAV file and text to a server and receive synthesized audio.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--url",
        "-u",
        type=str,
        default="http://127.0.0.1:8080/v1/tts",
        help="URL of the server",
    )
    parser.add_argument(
        "--text", "-t", type=str, required=True, help="Text to be synthesized"
    )
    parser.add_argument(
        "--reference_id",
        "-id",
        type=str,
        default=None,
        help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)",
    )
    parser.add_argument(
        "--reference_audio",
        "-ra",
        type=str,
        nargs="+",
        default=None,
        help="Path to the audio file",
    )
    parser.add_argument(
        "--reference_text",
        "-rt",
        type=str,
        nargs="+",
        default=None,
        help="Reference text for voice synthesis",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default="generated_audio",
        help="Output audio file name",
    )
    parser.add_argument(
        "--play",
        type=_str2bool,
        default=True,
        help="Whether to play audio after receiving data (true/false)",
    )
    parser.add_argument(
        "--normalize",
        type=_str2bool,
        default=True,
        help="Whether the server should normalize the input text (true/false)",
    )
    parser.add_argument(
        "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
    )
    parser.add_argument(
        "--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kbps"
    )
    parser.add_argument("--opus_bitrate", type=int, default=-1000)
    parser.add_argument(
        "--latency",
        type=str,
        default="normal",
        choices=["normal", "balanced"],
        help="Used in api.fish.audio/v1/tts",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=0,
        help="Maximum new tokens to generate. \n0 means no limit.",
    )
    parser.add_argument(
        "--chunk_length", type=int, default=200, help="Chunk length for synthesis"
    )
    parser.add_argument(
        "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="Repetition penalty for synthesis",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Temperature for sampling"
    )
    parser.add_argument(
        "--streaming",
        type=_str2bool,
        default=False,
        help="Enable streaming response (true/false)",
    )
    parser.add_argument(
        "--channels", type=int, default=1, help="Number of audio channels"
    )
    parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
    parser.add_argument(
        "--use_memory_cache",
        type=str,
        default="never",
        choices=["on-demand", "never"],
        help="Cache encoded references codes in memory.\n"
        "If `on-demand`, the server will use cached encodings\n "
        "instead of encoding reference audio again.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="`None` means randomized inference, otherwise deterministic.\n"
        "It can't be used for fixing a timbre.",
    )
    return parser.parse_args()
def split_whole_frames(buffer: bytes, frame_size: int) -> tuple[bytes, bytes]:
    """Split ``buffer`` into whole PCM frames and a trailing partial frame.

    HTTP chunk boundaries need not fall on frame boundaries. PyAudio's
    ``Stream.write`` derives its frame count by integer division, so a
    trailing partial frame would be dropped and every later sample shifted.
    Callers therefore carry the remainder over to the next chunk.

    Args:
        buffer: Raw PCM bytes.
        frame_size: Bytes per frame (channels * sample width); must be > 0.

    Returns:
        ``(whole, rest)`` with ``len(whole) % frame_size == 0`` and
        ``whole + rest == buffer``.
    """
    cut = len(buffer) - len(buffer) % frame_size
    return buffer[:cut], buffer[cut:]


def _collect_references(args: argparse.Namespace) -> list:
    """Build the inline ``ServeReferenceAudio`` list from the CLI arguments.

    ``--reference_id`` takes priority: when it is given, no inline references
    are sent. Otherwise each ``--reference_audio`` path is paired with the
    ``--reference_text`` at the same position.

    Raises:
        SystemExit: if the numbers of reference audios and texts differ.
    """
    if args.reference_id is not None:
        return []
    audio_paths = args.reference_audio or []
    text_sources = args.reference_text or []
    if len(audio_paths) != len(text_sources):
        # zip() would silently drop the unmatched tail; fail loudly instead.
        raise SystemExit(
            f"--reference_audio got {len(audio_paths)} value(s) but "
            f"--reference_text got {len(text_sources)}; they must pair up."
        )
    return [
        ServeReferenceAudio(audio=audio_to_bytes(path), text=read_ref_text(text))
        for path, text in zip(audio_paths, text_sources)
    ]


def _stream_to_wav(response: "requests.Response", args: argparse.Namespace) -> str:
    """Save a streamed raw-PCM response to ``<output>.wav``, playing it if asked.

    NOTE(review): assumes the server streams headerless 16-bit PCM whose
    channel count and sample rate match ``--channels`` / ``--rate``; confirm
    against the server.

    Returns:
        Path of the written WAV file.
    """
    audio_format = pyaudio.paInt16  # assumed 16-bit PCM (see note above)
    sample_width = pyaudio.get_sample_size(audio_format)
    frame_size = args.channels * sample_width
    audio_path = f"{args.output}.wav"

    pa = pyaudio.PyAudio() if args.play else None
    stream = None
    try:
        if pa is not None:
            stream = pa.open(
                format=audio_format, channels=args.channels, rate=args.rate, output=True
            )
        with wave.open(audio_path, "wb") as wf:
            wf.setnchannels(args.channels)
            wf.setsampwidth(sample_width)
            wf.setframerate(args.rate)
            pending = b""
            for chunk in response.iter_content(chunk_size=1024):
                if not chunk:  # keep-alive: nothing to play, keep listening
                    continue
                frames, pending = split_whole_frames(pending + chunk, frame_size)
                if not frames:
                    continue
                wf.writeframesraw(frames)
                if stream is not None:
                    stream.write(frames)
        if stream is not None:
            # Let buffered audio finish; closing an active stream discards
            # its pending buffers, cutting off the end of the speech.
            stream.stop_stream()
    finally:
        if stream is not None:
            stream.close()
        if pa is not None:
            pa.terminate()
    return audio_path


def _save_and_play(response: "requests.Response", args: argparse.Namespace) -> str:
    """Write a complete (non-streamed) response to ``<output>.<format>`` and optionally play it.

    Returns:
        Path of the written audio file.
    """
    audio_path = f"{args.output}.{args.format}"
    with open(audio_path, "wb") as audio_file:
        audio_file.write(response.content)
    if args.play:
        # Decode only when playing; pydub typically needs ffmpeg for mp3/flac.
        play(AudioSegment.from_file(audio_path, format=args.format))
    return audio_path


if __name__ == "__main__":
    args = parse_args()
    data = {
        "text": args.text,
        # Priority: --reference_id > inline [{audio, text}, ...] references.
        "references": _collect_references(args),
        "reference_id": args.reference_id,
        "normalize": args.normalize,
        "format": args.format,
        "mp3_bitrate": args.mp3_bitrate,
        "opus_bitrate": args.opus_bitrate,
        # Forward --latency so the option reaches the server.
        # NOTE(review): assumes ServeTTSRequest declares a `latency` field.
        "latency": args.latency,
        "max_new_tokens": args.max_new_tokens,
        "chunk_length": args.chunk_length,
        "top_p": args.top_p,
        "repetition_penalty": args.repetition_penalty,
        "temperature": args.temperature,
        "streaming": args.streaming,
        "use_memory_cache": args.use_memory_cache,
        "seed": args.seed,
    }
    pydantic_data = ServeTTSRequest(**data)

    # `with` guarantees the connection is released, which matters with stream=True.
    with requests.post(
        args.url,
        data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
        stream=args.streaming,
        headers={
            # Placeholder credential sent verbatim; replace for authenticated servers.
            "authorization": "Bearer YOUR_API_KEY",
            "content-type": "application/msgpack",
        },
    ) as response:
        if response.status_code != 200:
            print(f"Request failed with status code {response.status_code}")
            try:
                print(response.json())
            except ValueError:  # body is not JSON (e.g. an HTML error page)
                print(response.text)
            raise SystemExit(1)
        if args.streaming:
            saved_path = _stream_to_wav(response, args)
        else:
            saved_path = _save_and_play(response, args)
    print(f"Audio has been saved to '{saved_path}'.")