import random
import torch
from slam_llm.utils.model_utils import get_custom_model_factory
from utils.snac_utils import reconscruct_snac, reconstruct_tensors, layershift
import whisper
import numpy as np
from s2s_config import InferenceConfig, CKPT_PATH, CKPT_REPO, CKPT_LOCAL_DIR, CKPT_NAME
import os
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from typing import Callable


def update_progress(progress_callback: Callable[[str], None] | None, message: str):
    """Forward a status message to the caller-supplied callback, if any."""
    if progress_callback:
        progress_callback(message)


def pull_model_ckpt():
    """Download the model checkpoint from the Hugging Face Hub unless it is already present locally."""
    if not os.path.exists(CKPT_LOCAL_DIR):
        os.makedirs(CKPT_LOCAL_DIR)
    if os.path.exists(CKPT_PATH):
        return
    hf_hub_download(
        repo_id=CKPT_REPO,
        filename=CKPT_NAME,
        local_dir=CKPT_LOCAL_DIR,
        token=os.getenv("HF_TOKEN"),
    )


pull_model_ckpt()


def extract_audio_feature(audio_path, mel_size):
    """Load a wav file and return its Whisper log-mel spectrogram together with
    the post-encoder sequence length (Whisper downsamples by 2, the adapter by 5)."""
    print("Extracting audio features from", audio_path)
    audio_raw = whisper.load_audio(audio_path)
    audio_raw = whisper.pad_or_trim(audio_raw)
    audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size).permute(1, 0)
    audio_length = (audio_mel.shape[0] + 1) // 2
    audio_length = audio_length // 5
    return audio_mel, audio_length


def get_input_ids(length, special_token_a, special_token_t, vocab_config):
    """Build placeholder input ids: one layer-shifted audio sequence per codec
    layer, followed by a single text-layer sequence."""
    input_ids = []
    for i in range(vocab_config.code_layer):
        input_ids_item = []
        input_ids_item.append(layershift(vocab_config.input_a, i))
        input_ids_item += [layershift(vocab_config.pad_a, i)] * length
        input_ids_item += [
            layershift(vocab_config.eoa, i),
            layershift(special_token_a, i),
        ]
        input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
    input_id_T = torch.tensor(
        [vocab_config.input_t]
        + [vocab_config.pad_t] * length
        + [vocab_config.eot, special_token_t]
    )
    input_ids.append(input_id_T.unsqueeze(0))
    return input_ids


def generate_from_wav(
    wav_path, model, codec_decoder, dataset_config, decode_config, device
):
    """Run speech-to-speech generation for a single wav file and return the
    decoded waveform together with the generated text."""
    mel_size = dataset_config.mel_size
    prompt = dataset_config.prompt
    prompt_template = "USER: {}\n ASSISTANT: "
    vocab_config = dataset_config.vocab_config
    special_token_a = vocab_config.answer_a
    special_token_t = vocab_config.answer_t
    code_layer = vocab_config.code_layer
    task_type = dataset_config.task_type

    audio_mel, audio_length = extract_audio_feature(wav_path, mel_size)

    prompt = prompt_template.format(prompt)
    prompt_ids = model.tokenizer.encode(prompt)
    prompt_length = len(prompt_ids)
    prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)

    example_ids = get_input_ids(
        audio_length + prompt_length, special_token_a, special_token_t, vocab_config
    )

    # Splice the tokenized prompt into the text layer:
    # <bos> <audio placeholders> <prompt> <eos> <task>
    text_layer = example_ids[code_layer]
    text_layer = torch.cat(
        (
            text_layer[:, : audio_length + 1],
            prompt_ids.unsqueeze(0),
            text_layer[:, -2:],
        ),
        dim=1,
    )
    example_ids[code_layer] = text_layer

    input_length = audio_length
    example_mask = example_ids[0][0].ge(-1)
    example_ids = torch.stack(example_ids).squeeze()

    input_ids = example_ids.unsqueeze(0).to(device)
    attention_mask = example_mask.unsqueeze(0).to(device)
    audio_mel = audio_mel.unsqueeze(0).to(device)
    input_length = torch.tensor([input_length]).to(device)
    audio_length = torch.tensor([audio_length]).to(device)
    task_type = [task_type]

    # Mark which positions hold audio features (everything after <bos>).
    modality_mask = torch.zeros_like(attention_mask)
    padding_left = 1  # +1 for <bos>
    modality_mask[0, padding_left : padding_left + audio_length] = True

    batch = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "audio_mel": audio_mel,
        "input_length": input_length,
        "audio_length": audio_length,
        "modality_mask": modality_mask,
        "task_types": task_type,
    }

    model_outputs = model.generate(**batch, **decode_config)
    # The last entry is the text layer; the first seven are SNAC codec layers.
    text_outputs = model_outputs[7]
    audio_outputs = model_outputs[:7]
    output_text = model.tokenizer.decode(
        text_outputs, add_special_tokens=False, skip_special_tokens=True
    )

    if decode_config.decode_text_only:
        return None, output_text

    # Reassemble the layer-wise codec tokens and decode them to a waveform.
    audio_tokens = [audio_outputs[layer] for layer in range(7)]
    audiolist = reconscruct_snac(audio_tokens)
    audio = reconstruct_tensors(audiolist)
    with torch.inference_mode():
        audio_hat = codec_decoder.decode(audio)

    return audio_hat, output_text


# Lazily initialized globals so repeated calls to generate() reuse the model.
model = None
codec_decoder = None
device = None


def generate(
    wav_path: str, progress_callback: Callable[[str], None] | None = None
) -> tuple[np.ndarray, int | float]:
    global model, codec_decoder, device
    config = OmegaConf.structured(InferenceConfig())
    train_config, model_config, dataset_config, decode_config = (
        config.train_config,
        config.model_config,
        config.dataset_config,
        config.decode_config,
    )

    torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)

    if model is None or codec_decoder is None or device is None:
        update_progress(progress_callback, "Loading model")
        model_factory = get_custom_model_factory(model_config)
        model, _ = model_factory(train_config, model_config, CKPT_PATH)
        codec_decoder = model.codec_decoder
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()

    update_progress(progress_callback, "Generating")
    output_wav, output_text = generate_from_wav(
        wav_path, model, codec_decoder, dataset_config, decode_config, device
    )
    # Assumes decode_config.decode_text_only is False (otherwise output_wav is
    # None). The codec output is returned at 24 kHz.
    return output_wav.squeeze().cpu().numpy(), 24000


if __name__ == "__main__":
    wav_path = "sample.wav"
    generate(wav_path)
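# Usage sketch (illustrative, not part of this script): generate() returns a
# (waveform, sample_rate) pair, so the result can be written to disk with any
# audio I/O library. The `soundfile` dependency below is an assumption and is
# not required by this module.
#
#   import soundfile as sf
#   wav, sr = generate("sample.wav")
#   sf.write("response.wav", wav, sr)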