# fishspeech2 / tools/schema.py
# Uploaded by pineconeT94 — commit 8b14bed ("first commit")
import os
import queue
from dataclasses import dataclass
from typing import Annotated, Literal, Optional
import torch
from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
from pydantic.functional_validators import SkipValidation
from fish_speech.conversation import Message, TextPart, VQPart
# Sample count taken from the GLOBAL_NUM_SAMPLES environment variable
# (falls back to 1 when unset; a non-integer value raises ValueError at import).
GLOBAL_NUM_SAMPLES = int(os.environ.get("GLOBAL_NUM_SAMPLES", 1))
class ServeVQPart(BaseModel):
    """Message part that carries VQ codes as a nested (2-level) list of ints."""

    # Literal tag identifying this kind of part.
    type: Literal["vq"] = "vq"
    # Wrapped in SkipValidation: pydantic stores the nested list as given,
    # without element-wise validation.
    codes: SkipValidation[list[list[int]]]
class ServeTextPart(BaseModel):
    """Message part that carries plain text."""

    # Literal tag identifying this kind of part.
    type: Literal["text"] = "text"
    text: str
class ServeAudioPart(BaseModel):
    """Message part that carries a raw audio payload."""

    # Literal tag identifying this kind of part.
    type: Literal["audio"] = "audio"
    # NOTE(review): encoding of these bytes is not stated in this module — confirm.
    audio: bytes
@dataclass
class ASRPackRequest:
    """Work item pairing an audio tensor with a result queue and a language hint.

    Field order is part of the generated constructor signature:
    ``ASRPackRequest(audio, result_queue, language)``.
    """

    # Audio to recognise. NOTE(review): shape/dtype not visible here — confirm.
    audio: torch.Tensor
    # Presumably where the ASR worker puts the result for this item — confirm.
    result_queue: queue.Queue
    # Language hint for recognition.
    language: str
class ServeASRRequest(BaseModel):
    """Batch speech-recognition request."""

    # Each entry should be uncompressed PCM float16 audio.
    audios: list[bytes]
    sample_rate: int = 44100
    # "auto" presumably leaves language detection to the recogniser.
    language: Literal["zh", "en", "ja", "auto"] = "auto"
class ServeASRTranscription(BaseModel):
    """Transcription result for one audio (presumably one per input entry)."""

    text: str
    # Units not stated here; presumably seconds.
    duration: float
    # NOTE(review): meaning not visible in this module — presumably flags an
    # unusually long gap/silence in the audio; confirm with the producer.
    huge_gap: bool
class ServeASRSegment(BaseModel):
    """A span of recognised text with its start/end offsets."""

    text: str
    # Offsets; units not stated here (presumably seconds).
    start: float
    end: float
class ServeTimedASRResponse(BaseModel):
    """Recognition result with the full text plus per-segment timing."""

    text: str
    segments: list[ServeASRSegment]
    # Units not stated here; presumably seconds.
    duration: float
class ServeASRResponse(BaseModel):
    """Batch recognition response: a list of transcriptions."""

    transcriptions: list[ServeASRTranscription]
class ServeMessage(BaseModel):
role: Literal["system", "assistant", "user"]
parts: list[ServeVQPart | ServeTextPart]
def to_conversation_message(self):
new_message = Message(role=self.role, parts=[])
for part in self.parts:
if isinstance(part, ServeTextPart):
new_message.parts.append(TextPart(text=part.text))
elif isinstance(part, ServeVQPart):
new_message.parts.append(
VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
)
else:
raise ValueError(f"Unsupported part type: {part}")
return new_message
class ServeRequest(BaseModel):
    """Generation request: a non-empty conversation plus sampling settings."""

    # At least one message is required. ``conlist(...)`` returns a type, not a
    # constraint object, so using it as ``Annotated`` metadata was ignored by
    # pydantic and empty lists slipped through; ``Field(min_length=1)`` is the
    # metadata form pydantic actually enforces.
    messages: Annotated[list[ServeMessage], Field(min_length=1)]
    max_new_tokens: int = 1024
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    streaming: bool = False
    # Presumably the number of independent samples to generate (see
    # ServeStreamResponse.sample_id) — confirm.
    num_samples: int = 1
    # NOTE(review): semantics not visible in this module — presumably a
    # stopping threshold consumed by the generator; confirm.
    early_stop_threshold: float = 1.0
class ServeVQGANEncodeRequest(BaseModel):
    """Request to encode audio files into VQ tokens."""

    # Each entry is an encoded audio file (WAV, MP3, etc.), not raw PCM.
    audios: list[bytes]
class ServeVQGANEncodeResponse(BaseModel):
    """VQ tokens produced by encoding; a 3-level nested list of ints."""

    # SkipValidation: stored as given, without element-wise validation.
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeRequest(BaseModel):
    """Request to decode VQ tokens (3-level nested list of ints) back to audio."""

    # SkipValidation: stored as given, without element-wise validation.
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeResponse(BaseModel):
    """Decoded audio produced from VQ tokens."""

    # Each entry is raw PCM float16 audio.
    audios: list[bytes]
# NOTE: ServeReferenceAudio is declared once, further down in this module.
# A second, identical-field declaration used to live here; it was always
# replaced by the later one at import time and nothing referenced it in
# between, so it was removed as dead code.
class ServeForwardMessage(BaseModel):
    """Plain role/content chat message (free-form role string)."""

    role: str
    content: str
class ServeResponse(BaseModel):
    """Non-streaming generation response."""

    messages: list[ServeMessage]
    # None when no finish reason is reported.
    finish_reason: Optional[Literal["stop", "error"]] = None
    # Free-form statistics; pydantic copies this default per instance.
    stats: dict[str, int | float | str] = {}
class ServeStreamDelta(BaseModel):
    """Incremental update in a streaming response; both fields are optional."""

    role: Optional[Literal["system", "assistant", "user"]] = None
    part: ServeVQPart | ServeTextPart | None = None
class ServeStreamResponse(BaseModel):
    """One event of a streaming generation response."""

    # Index of the sample this event belongs to.
    sample_id: int = 0
    delta: Optional[ServeStreamDelta] = None
    finish_reason: Optional[Literal["stop", "error"]] = None
    stats: Optional[dict[str, int | float | str]] = None
class ServeReferenceAudio(BaseModel):
    """Reference audio clip plus its accompanying text (presumably the transcript)."""

    audio: bytes
    text: str

    def __repr__(self) -> str:
        # Report only the byte count so repr() of a request does not dump
        # the whole audio payload.
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
class ServeChatRequestV1(BaseModel):
    """Chat request with optional input audio and TTS output settings."""

    model: str = "llama3-8b"
    messages: list[ServeForwardMessage] = []
    # Optional input audio; encoding not stated here.
    audio: Optional[bytes] = None
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    voice: str = "jessica"
    # Output encoding for synthesised speech.
    tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
    # Allowed bitrates; units presumably kbps — confirm.
    tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
class ServeTTSRequest(BaseModel):
    """Text-to-speech request: text, output format, voice references and sampling settings."""

    text: str
    # NOTE(review): ``conint(...)`` returns a type, not a constraint object, so
    # as ``Annotated`` metadata pydantic ignores it — the 100..300 / strict
    # bounds are NOT enforced and values such as 0 are accepted today. Confirm
    # no caller relies on that before switching to
    # ``Field(ge=100, le=300, strict=True)``.
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    # MP3 bitrate. Declared once. This field used to be declared twice; the
    # later ``Optional[int] = 64`` declaration was the one in effect, so that
    # type and default are kept, at the field's original position.
    mp3_bitrate: Optional[int] = 64
    # Reference audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    seed: int | None = None
    use_memory_cache: Literal["on-demand", "never"] = "never"
    # Normalize text for en & zh; this increases stability for numbers
    normalize: bool = True
    # NOTE(review): -1000 equals libopus's OPUS_AUTO constant — presumably
    # "let the encoder choose"; confirm against the encoder call.
    opus_bitrate: Optional[int] = -1000
    # Balanced mode reduces latency to ~300 ms, but may decrease stability
    latency: Literal["normal", "balanced"] = "normal"
    # Rarely changed settings below
    streaming: bool = False
    max_new_tokens: int = 1024
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7