# whispervq/services/AudioTokenizerService.py
# Uploaded by jan-hq: "Upload folder using huggingface_hub" (commit b46f992, verified)
import io
import os
from huggingface_hub import hf_hub_download
from models.audio import AudioFormat, FORMAT_BACKENDS
import tempfile
import logging
import torchaudio
from fastapi import HTTPException
from fastapi.responses import JSONResponse
import torch
from typing import Tuple
from utils.custom_component import CustomRQBottleneckTransformer
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class AudioTokenizerService:
    """Decode uploaded audio and convert it into WhisperVQ sound tokens.

    Construction downloads the ``whisper-vq-stoks-v3-7lang-fixed.model``
    checkpoint from the ``jan-hq/WhisperVQ`` Hugging Face repository into the
    parent of this file's directory (if it is not already there), then loads
    it on CUDA when available and on CPU otherwise.
    """

    # Checkpoint file name; also reported as "model_name" by tokenize().
    MODEL_FILENAME = "whisper-vq-stoks-v3-7lang-fixed.model"

    def __init__(self):
        self.available_backends = torchaudio.list_audio_backends()
        logger.info(f"Available backends: {self.available_backends}")
        # Parent of the directory that contains this file.
        main_directory = os.path.dirname(
            os.path.dirname(os.path.realpath(__file__)))
        # Verify ffmpeg support
        self.has_ffmpeg = "ffmpeg" in self.available_backends
        if not self.has_ffmpeg:
            logger.warning(
                "FFMPEG backend not available. Some formats may not be supported")
        # Resolved once so the model and every input tensor share one device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        model_path = os.path.join(main_directory, self.MODEL_FILENAME)
        if not os.path.exists(model_path):
            model_path = hf_hub_download(
                repo_id="jan-hq/WhisperVQ",
                filename=self.MODEL_FILENAME,
                local_dir=main_directory,
            )
        self.vq_model = CustomRQBottleneckTransformer.load_vq_only(
            model_path
        ).to(self.device)
        self.vq_model.load_encoder(self.device)
        self.vq_model.eval()

    def _get_best_backend(self, format: AudioFormat) -> str:
        """Return the first backend listed for ``format`` that is installed.

        Candidates come from ``FORMAT_BACKENDS[format]`` in their listed order.

        Raises:
            ValueError: if none of the candidate backends is available.
        """
        supported_backends = FORMAT_BACKENDS[format]
        for backend in supported_backends:
            if backend in self.available_backends:
                return backend
        raise ValueError(f"No available backend supports format {format}")

    @staticmethod
    def _format_name(format):
        """Return the plain format value (e.g. ``"wav"``) for an enum member or string."""
        # Formatting an Enum member with f"{...}" may render "AudioFormat.X"
        # rather than its value (depends on the enum's base and Python version);
        # file suffixes and JSON output need the bare value.
        return getattr(format, "value", format)

    @staticmethod
    def _decode_pcm16(data: bytes) -> torch.Tensor:
        """Decode headerless signed 16-bit PCM into a ``(1, num_samples)`` float tensor.

        Samples are read in the host's native byte order and divided by 32768,
        so values fall in [-1, 1).

        Raises:
            ValueError: if ``data`` is empty or not a whole number of samples.
        """
        if not data:
            raise ValueError("empty PCM payload")
        if len(data) % 2:
            raise ValueError(
                f"PCM payload length {len(data)} is not a whole number of 16-bit samples")
        # torch.frombuffer aliases its input and warns on read-only buffers such
        # as bytes; a bytearray copy gives it a private, writable buffer.
        samples = torch.frombuffer(bytearray(data), dtype=torch.int16)
        return (samples.float() / 32768.0).unsqueeze(0)

    def _decode_container(
        self, data: bytes, format, backend: str
    ) -> Tuple[torch.Tensor, int]:
        """Decode a non-PCM payload with torchaudio using ``backend``."""
        if os.name == "nt":
            # On Windows a NamedTemporaryFile cannot be reopened by name while
            # it is still open, so decode straight from memory instead.
            return torchaudio.load(io.BytesIO(data), backend=backend)
        suffix = f".{self._format_name(format)}"
        with tempfile.NamedTemporaryFile(suffix=suffix) as temp_file:
            temp_file.write(data)
            temp_file.flush()
            return torchaudio.load(temp_file.name, backend=backend)

    def load_audio(
        self,
        file_obj: bytes,
        format: AudioFormat,
        target_sr: int = 16000
    ) -> Tuple[torch.Tensor, int]:
        """
        Load audio from bytes object with format handling.

        Multi-channel audio is averaged down to a single channel, and audio
        whose native rate differs from ``target_sr`` is resampled.

        Args:
            file_obj: Audio file bytes (raw 16-bit PCM when ``format`` is PCM).
            format: Audio format enum
            target_sr: Target sample rate (default: 16000); also the rate
                assumed for raw PCM input.

        Returns:
            Tuple[torch.Tensor, int]: mono audio tensor and its sample rate
            (always ``target_sr``).

        Raises:
            HTTPException: status 400 if no backend supports ``format`` or the
                payload cannot be decoded.
        """
        try:
            backend = self._get_best_backend(format)
            logger.info(f"Using {backend} backend for {format} format")
            if format == AudioFormat.PCM:
                wav = self._decode_pcm16(file_obj)
                sr = target_sr
            else:
                # The backend is passed per call: torchaudio.set_audio_backend is
                # deprecated and a no-op under the backend dispatcher, and it
                # would mutate process-global state shared by concurrent requests.
                wav, sr = self._decode_container(file_obj, format, backend)
            # Convert to mono if stereo
            if wav.shape[0] > 1:
                wav = torch.mean(wav, dim=0, keepdim=True)
            # Resample if needed
            if sr != target_sr:
                wav = torchaudio.functional.resample(wav, sr, target_sr)
                sr = target_sr
            return wav, sr
        except Exception as e:
            logger.error(f"Error loading audio: {e}")
            raise HTTPException(
                status_code=400,
                detail=f"Error processing {format} audio: {str(e)}"
            ) from e

    def get_format_info(self) -> dict:
        """Report, for every AudioFormat, whether a backend is available and which one.

        Returns:
            dict mapping each format to ``{"supported": bool, "backend": str | None}``.
        """
        supported_formats = {}
        for format in AudioFormat:
            try:
                backend = self._get_best_backend(format)
                supported_formats[format] = {
                    "supported": True,
                    "backend": backend
                }
            except ValueError:
                supported_formats[format] = {
                    "supported": False,
                    "backend": None
                }
        return supported_formats

    def tokenize(self, audio_data: bytes, format: AudioFormat = "wav"):
        """Encode an audio payload into a WhisperVQ sound-token string.

        Returns:
            JSONResponse with "model_name", "tokens" (``<|sound_start|>``, one
            ``<|sound_NNNN|>`` per code, then ``<|sound_end|>``), "format",
            "sample_rate" and "backend_used".

        Raises:
            HTTPException: 400 when the audio cannot be loaded (propagated from
                load_audio unchanged), 500 for any other failure.
        """
        try:
            wav, sr = self.load_audio(audio_data, format)
            wav = wav.to(self.device)
            # Generate tokens
            with torch.no_grad():
                codes = self.vq_model.encode_audio(wav)
            # Keep the first row of the encoder output — presumably the single
            # batch item; confirm against encode_audio's return shape.
            codes = codes[0].cpu().tolist()
            result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
            return JSONResponse(content={
                "model_name": self.MODEL_FILENAME,
                "tokens": f'<|sound_start|>{result}<|sound_end|>',
                "format": self._format_name(format),
                "sample_rate": sr,
                "backend_used": self._get_best_backend(format)
            })
        except HTTPException:
            # load_audio already produced a 400 for bad input; re-wrapping it
            # here would turn a client error into a server error.
            raise
        except Exception as e:
            logger.exception(f"Error processing request: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"Error processing request: {str(e)}"
            ) from e
# Lazily created shared service instance; stays None until first requested.
_audio_tokenizer_service = None


def get_audio_tokenizer_service():
    """Return the shared AudioTokenizerService, constructing it on first call."""
    global _audio_tokenizer_service
    service = _audio_tokenizer_service
    if service is None:
        service = _audio_tokenizer_service = AudioTokenizerService()
    return service