# chatterbox/handler.py
import base64
import io
from typing import Any, Dict, List

import soundfile as sf
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the Chatterbox TTS model onto the GPU once when the endpoint starts.
        try:
            self.model = ChatterboxTTS.from_pretrained(device="cuda")
        except Exception as e:
            raise RuntimeError(f"[ERROR] Failed to load model: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        try:
            inputs = data.get("inputs", {})
            text = inputs.get("text")
            if not text:
                return [{"error": "Missing 'text' in inputs."}]
            exaggeration = inputs.get("exaggeration", 0.3)
            cfg_weight = inputs.get("cfg_weight", 0.5)
            print(f"[INFO] exaggeration={exaggeration}, cfg_weight={cfg_weight}")

            # Fetch the reference voice prompt; hf_hub_download caches it locally after the first call.
            audio_prompt_path = hf_hub_download(
                repo_id="aiplexdeveloper/chatterbox",
                filename="arjun_das_output_audio.mp3",
            )
            wav = self.model.generate(
                text,
                audio_prompt_path=audio_prompt_path,
                exaggeration=exaggeration,
                cfg_weight=cfg_weight,
            )

            # Serialize the waveform to an in-memory WAV file and encode it as base64.
            buffer = io.BytesIO()
            sf.write(buffer, wav.cpu().numpy().T, self.model.sr, format="WAV")
            buffer.seek(0)
            audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")

            # wav is shaped (channels, samples); use the last dimension for the duration.
            audio_length_seconds = wav.shape[-1] / self.model.sr
            return [{"audio_base64": audio_base64, "audio_length_seconds": audio_length_seconds}]
        except Exception as e:
            print(f"[ERROR] Inference failed: {e}")
            return [{"error": str(e)}]