import logging

import torch
import torchaudio

from .constants import (
    AUD_CONTEXT_TOKEN,
    AUD_END_TOKEN,
    AUD_START_TOKEN,
    AUD_TAG_TOKEN,
    BOX_END_TOKEN,
    BOX_START_TOKEN,
    IMG_CONTEXT_TOKEN,
    IMG_END_TOKEN,
    IMG_START_TOKEN,
    IMG_TAG_TOKEN,
    PATCH_CONTEXT_TOKEN,
    PATCH_END_TOKEN,
    PATCH_START_TOKEN,
    QUAD_END_TOKEN,
    QUAD_START_TOKEN,
    REF_END_TOKEN,
    REF_START_TOKEN,
    VID_CONTEXT_TOKEN,
    VID_END_TOKEN,
    VID_START_TOKEN,
    VID_TAG_TOKEN,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def update_tokenizer_for_sensevoice_sparktts(tokenizer):
    """Extend a tokenizer with the multimodal marker tokens and the discrete audio vocabulary."""
    token_list = [
        IMG_START_TOKEN,
        IMG_END_TOKEN,
        IMG_CONTEXT_TOKEN,
        VID_START_TOKEN,
        VID_END_TOKEN,
        VID_CONTEXT_TOKEN,
        PATCH_START_TOKEN,
        PATCH_END_TOKEN,
        PATCH_CONTEXT_TOKEN,
        AUD_START_TOKEN,
        AUD_END_TOKEN,
        AUD_CONTEXT_TOKEN,
        QUAD_START_TOKEN,
        QUAD_END_TOKEN,
        REF_START_TOKEN,
        REF_END_TOKEN,
        BOX_START_TOKEN,
        BOX_END_TOKEN,
        IMG_TAG_TOKEN,
        VID_TAG_TOKEN,
        AUD_TAG_TOKEN,
    ]
    num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
    # 8192 discrete audio tokens, one per entry of the BiCodec semantic codebook.
    audio_token_list = [f"<|audio_{i}|>" for i in range(8192)]
    num_new_tokens += tokenizer.add_tokens(audio_token_list, special_tokens=False)
    return tokenizer


class SenseVoiceSparkTTSTokenizer:
    def __init__(self, model_name_or_path, rank=None):
        self.model_name_or_path = model_name_or_path
        if rank is None and torch.distributed.is_initialized():
            # Map the global rank onto a local GPU index (assumes 8 GPUs per node).
            rank = torch.distributed.get_rank()
            self.rank = rank % 8
        else:
            self.rank = rank
        logger.info(f"{self.rank=}")
        self.sampling_rate = 16000
        self.is_discrete = True
        self.is_contiguous = True
        # Interleaving pattern: Text, Audio, Text, Audio.
        self.text_audio_interval_ratio = [1, 10, 1, 10]

    def load_model(self):
        if hasattr(self, "model"):
            return
        if self.rank is not None:
            self.device = f"cuda:{self.rank}"
            torch.cuda.set_device(self.rank)
        else:
            self.device = "cpu"
        logger.info(f"{self.device=}")

        logger.info("Loading SenseVoiceSmall")
        from funasr.models.sense_voice.model import SenseVoiceSmall

        model_dir = "/data/models/FunAudioLLM/SenseVoiceSmall/"
        # Only the kwargs (frontend) are kept for fbank extraction; the ASR model itself is discarded.
        _, self.kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device=self.device)
        logger.info("Loading SenseVoiceSmall Done")

        logger.info("Loading BiCodecTokenizer")
        from sparktts.models.audio_tokenizer import BiCodecTokenizer

        model_dir = "/data/models/SparkAudio/Spark-TTS-0.5B/"
        self.model = BiCodecTokenizer(model_dir, device=self.device)
        logger.info("Loading BiCodecTokenizer Done")

    def encode(self, audio_path, is_discrete=False, is_contiguous=True, **kwargs):
        if not hasattr(self, "model"):
            self.load_model()
        # Exactly one of the two output modes must be requested.
        assert not (is_discrete and is_contiguous)
        assert is_discrete or is_contiguous

        if is_discrete:
            # Discrete mode: BiCodec semantic token ids.
            global_token_ids, semantic_token_ids = self.model.tokenize(audio_path)
            return semantic_token_ids[0].cpu().tolist()

        # Contiguous mode: fbank features from the SenseVoice frontend.
        from funasr.utils.load_utils import extract_fbank

        audio, sampling_rate = torchaudio.load(audio_path)
        audio = audio.mean(0)  # downmix to mono
        resampler = torchaudio.transforms.Resample(
            orig_freq=sampling_rate, new_freq=self.sampling_rate
        )
        audio = resampler(audio[None, :])[0, :]
        frontend = self.kwargs["frontend"]
        speech, speech_lengths = extract_fbank(audio, data_type="sound", frontend=frontend)
        return speech[0]

    def decode(self, prompt_speech_token, source_speech_16k=None):
        if not hasattr(self, "model"):
            self.load_model()
        semantic_token_ids = torch.tensor(prompt_speech_token, dtype=torch.long).unsqueeze(0)
        if source_speech_16k is None:
            # No reference speech: fall back to all-zero global (speaker) tokens.
            global_token_ids = torch.zeros((1, 1, 32), dtype=torch.long)
        else:
            # Derive the global speaker tokens from the reference speech.
            global_token_ids, _ = self.model.tokenize(source_speech_16k)
        logger.debug(f"{global_token_ids=}")
        audio = self.model.detokenize(
            global_token_ids.to(self.device).squeeze(0),
            semantic_token_ids.to(self.device),
        )
        return torch.tensor(audio)

    def apply_to_role(self, role, **kwargs):
        # Discrete audio tokens are produced on the model side; contiguous
        # features are consumed on the user side.
        is_discrete = kwargs.get("is_discrete", False)
        if is_discrete and role in ["assistant", "gpt"]:
            return True
        is_contiguous = kwargs.get("is_contiguous", False)
        if is_contiguous and role in ["user", "human"]:
            return True
        return False
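
# Usage sketch, for illustration only: "example.wav" is a hypothetical input
# file, and the checkpoint paths hard-coded in load_model() must exist locally
# for this to run.
#
#   tokenizer = SenseVoiceSparkTTSTokenizer("/data/models/SparkAudio/Spark-TTS-0.5B/")
#   # Contiguous fbank features, as consumed for user/human turns:
#   speech = tokenizer.encode("example.wav", is_discrete=False, is_contiguous=True)
#   # Discrete BiCodec semantic ids, as produced for assistant/gpt turns:
#   ids = tokenizer.encode("example.wav", is_discrete=True, is_contiguous=False)
#   # Reconstruct a waveform (zero global tokens, since no reference speech is given):
#   wav = tokenizer.decode(ids)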