|
import io |
|
import os |
|
from huggingface_hub import hf_hub_download |
|
from models.audio import AudioFormat, FORMAT_BACKENDS |
|
import tempfile |
|
import logging |
|
import torchaudio |
|
from fastapi import HTTPException |
|
from fastapi.responses import JSONResponse |
|
import torch |
|
from typing import Tuple |
|
from utils.custom_component import CustomRQBottleneckTransformer |
|
# Module-level logger scoped to this module's dotted name (standard convention).
logger = logging.getLogger(__name__)
|
|
|
|
|
class AudioTokenizerService:
    """Encodes audio into discrete WhisperVQ sound tokens.

    On construction the service inspects the available torchaudio backends,
    downloads the WhisperVQ checkpoint from the Hugging Face Hub if it is not
    already cached in the project directory, and loads the quantizer/encoder
    onto CUDA when available (CPU otherwise).
    """

    # Checkpoint filename; also reported to clients as "model_name".
    MODEL_FILENAME = "whisper-vq-stoks-v3-7lang-fixed.model"

    def __init__(self):
        self.available_backends = torchaudio.list_audio_backends()
        logger.info(f"Available backends: {self.available_backends}")

        # Project root: parent of the directory containing this file; the
        # model checkpoint is cached there between runs.
        main_directory = os.path.dirname(
            os.path.dirname(os.path.realpath(__file__)))

        self.has_ffmpeg = "ffmpeg" in self.available_backends
        if not self.has_ffmpeg:
            logger.warning(
                "FFMPEG backend not available. Some formats may not be supported")

        device = "cuda" if torch.cuda.is_available() else "cpu"

        model_path = os.path.join(main_directory, self.MODEL_FILENAME)
        if not os.path.exists(model_path):
            hf_hub_download(
                repo_id="jan-hq/WhisperVQ",
                filename=self.MODEL_FILENAME,
                local_dir=main_directory,
            )
        self.vq_model = CustomRQBottleneckTransformer.load_vq_only(
            model_path
        ).to(device)
        self.vq_model.load_encoder(device)
        self.vq_model.eval()

    def _get_best_backend(self, format: AudioFormat) -> str:
        """Return the first backend preferred for *format* that is installed.

        Raises:
            ValueError: if none of the format's preferred backends is
                available in this environment.
        """
        supported_backends = FORMAT_BACKENDS[format]
        for backend in supported_backends:
            if backend in self.available_backends:
                return backend
        raise ValueError(f"No available backend supports format {format}")

    def load_audio(
        self,
        file_obj: bytes,
        format: AudioFormat,
        target_sr: int = 16000
    ) -> Tuple[torch.Tensor, int]:
        """Decode raw audio bytes into a mono float tensor.

        Args:
            file_obj: Audio file bytes (raw samples for PCM, a full container
                for every other format).
            format: Audio format enum.
            target_sr: Target sample rate in Hz (default: 16000).

        Returns:
            Tuple[torch.Tensor, int]: (1, num_samples) float tensor scaled to
            [-1, 1] and the (post-resample) sample rate.

        Raises:
            HTTPException: 400 for any decoding failure.
        """
        try:
            backend = self._get_best_backend(format)
            # NOTE(review): set_audio_backend is deprecated in recent
            # torchaudio releases (backend selection is now per-call) —
            # confirm the pinned torchaudio version still supports it.
            torchaudio.set_audio_backend(backend)
            logger.info(f"Using {backend} backend for {format} format")

            if format == AudioFormat.PCM:
                # Raw 16-bit little-endian PCM: no container to parse.
                if len(file_obj) % 2 != 0:
                    raise ValueError(
                        "PCM payload must contain whole 16-bit samples")
                # Copy into a bytearray: torch.frombuffer over read-only
                # bytes yields a writable tensor aliasing immutable memory
                # (torch emits a warning and writes would be undefined).
                wav = torch.frombuffer(bytearray(file_obj), dtype=torch.int16)
                wav = wav.float() / 32768.0  # int16 full scale -> [-1, 1)
                wav = wav.unsqueeze(0)       # (samples,) -> (1, samples)
                sr = target_sr               # raw PCM carries no rate; assume target
            else:
                if os.name == "nt":
                    # Windows cannot reopen a NamedTemporaryFile while it is
                    # still open, so decode straight from memory there.
                    wav, sr = torchaudio.load(io.BytesIO(file_obj))
                else:
                    with tempfile.NamedTemporaryFile(suffix=f".{format}") as temp_file:
                        temp_file.write(file_obj)
                        temp_file.flush()
                        wav, sr = torchaudio.load(temp_file.name)

            # Downmix multi-channel audio to mono by averaging channels.
            if wav.shape[0] > 1:
                wav = torch.mean(wav, dim=0, keepdim=True)

            # Resample to the rate the VQ model expects.
            if sr != target_sr:
                wav = torchaudio.functional.resample(wav, sr, target_sr)
                sr = target_sr

            return wav, sr

        except Exception as e:
            logger.error(f"Error loading audio: {e}")
            raise HTTPException(
                status_code=400,
                detail=f"Error processing {format} audio: {str(e)}"
            )

    def get_format_info(self) -> dict:
        """Get information about supported formats.

        Returns:
            dict: maps each AudioFormat to {"supported": bool,
            "backend": str | None}.
        """
        supported_formats = {}
        for format in AudioFormat:
            try:
                backend = self._get_best_backend(format)
                supported_formats[format] = {
                    "supported": True,
                    "backend": backend
                }
            except ValueError:
                supported_formats[format] = {
                    "supported": False,
                    "backend": None
                }
        return supported_formats

    def tokenize(self, audio_data: bytes, format: AudioFormat = "wav"):
        """Convert audio bytes into a WhisperVQ sound-token string.

        Args:
            audio_data: Raw audio file bytes.
            format: Audio format (default "wav"; NOTE(review): the default is
                a plain string — this assumes AudioFormat is str-based so
                comparisons and FORMAT_BACKENDS lookups still match. Confirm.)

        Returns:
            JSONResponse containing the token string wrapped in
            <|sound_start|>/<|sound_end|>, the format, the sample rate, and
            the backend used.

        Raises:
            HTTPException: 400 (propagated from load_audio) for bad input,
                500 for unexpected failures.
        """
        try:
            wav, sr = self.load_audio(audio_data, format)

            device = "cuda" if torch.cuda.is_available() else "cpu"
            wav = wav.to(device)

            with torch.no_grad():
                codes = self.vq_model.encode_audio(wav)
                codes = codes[0].cpu().tolist()

            result = ''.join(f'<|sound_{num:04d}|>' for num in codes)

            return JSONResponse(content={
                "model_name": self.MODEL_FILENAME,
                "tokens": f'<|sound_start|>{result}<|sound_end|>',
                "format": format,
                "sample_rate": sr,
                "backend_used": self._get_best_backend(format)
            })

        except HTTPException:
            # Bug fix: the broad handler below used to swallow the 400 raised
            # by load_audio and re-wrap it as a generic 500, hiding the real
            # status and detail from the client. Propagate it unchanged.
            raise
        except Exception as e:
            logger.error(f"Error processing request: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"Error processing request: {str(e)}"
            )
|
|
|
|
|
# Lazily-built process-wide service instance; populated on first access.
_audio_tokenizer_service = None


def get_audio_tokenizer_service():
    """Return the shared AudioTokenizerService, constructing it lazily.

    The first call builds the service (which may download the model
    checkpoint); every later call returns the cached instance.
    """
    global _audio_tokenizer_service
    if _audio_tokenizer_service is not None:
        return _audio_tokenizer_service
    _audio_tokenizer_service = AudioTokenizerService()
    return _audio_tokenizer_service
|
|