# Copyright (c) 2024 NVIDIA CORPORATION. 
#   Licensed under the MIT license.

import os
import string
import yaml
from copy import deepcopy

import torch
from transformers import AutoTokenizer, set_seed 
set_seed(0)

from data import AudioTextDataProcessor
from src.factory import create_model_and_transforms


def prepare_tokenizer(model_config):
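    """Load the text tokenizer and register the special tokens the model relies on
    (<audio>, <|endofchunk|>, plus fallback pad/sep tokens if they are missing)."""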
    tokenizer_path = model_config['tokenizer_path']
    cache_dir = model_config['cache_dir']
    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=False,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
    )
    if text_tokenizer.pad_token is None:
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    if text_tokenizer.sep_token is None:
        text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})
    return text_tokenizer


def prepare_model(model_config, clap_config, checkpoint_path, device=0):
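    """Instantiate the model in eval mode on `device` and load the fine-tuned weights from `checkpoint_path`."""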
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=False,
        gradient_checkpointing=False,
        freeze_lm_embeddings=False,
    )
    model.eval()
    model = model.to(device)

    # load the fine-tuned weights; strip the "module." prefix left by DistributedDataParallel
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model_state_dict = checkpoint["model_state_dict"]
    model_state_dict = {k.replace("module.", ""): v for k, v in model_state_dict.items()}
    model.load_state_dict(model_state_dict, strict=False)

    return model


def inference(model, tokenizer, item, processed_item, inference_kwargs, device=0):
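    """Generate a response for a single preprocessed item and return the decoded text(s)."""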
    filename, audio_clips, audio_embed_mask, input_ids, attention_mask = processed_item
    # move inputs to the target device; dtype is left unchanged
    audio_clips = audio_clips.to(device, dtype=None, non_blocking=True)
    audio_embed_mask = audio_embed_mask.to(device, dtype=None, non_blocking=True)
    input_ids = input_ids.to(device, dtype=None, non_blocking=True).squeeze()

    # special-token ids; only eos_token_id is consumed by generate() below
    media_token_id = tokenizer.encode("<audio>")[-1]
    eoc_token_id = tokenizer.encode("<|endofchunk|>")[-1]
    sep_token_id = tokenizer.sep_token_id
    eos_token_id = tokenizer.eos_token_id
    
    outputs = model.generate(
        audio_x=audio_clips.unsqueeze(0),
        audio_x_mask=audio_embed_mask.unsqueeze(0),
        lang_x=input_ids.unsqueeze(0),
        eos_token_id=eos_token_id,
        max_new_tokens=128,
        **inference_kwargs,
    )

    # keep only the text after the last separator and strip special tokens
    outputs_decoded = [
        tokenizer.decode(output)
        .split(tokenizer.sep_token)[-1]
        .replace(tokenizer.eos_token, '')
        .replace(tokenizer.pad_token, '')
        .replace('<|endofchunk|>', '')
        for output in outputs
    ]

    return outputs_decoded
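

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original pipeline): wires the
    # helpers above together for a single example. The config layout, checkpoint path,
    # AudioTextDataProcessor arguments, item fields, and sampling kwargs below are
    # assumptions -- check config.yaml and data.py for the actual interface.
    with open("config.yaml") as f:  # assumed config filename
        config = yaml.safe_load(f)
    model_config = config["model_config"]
    clap_config = config["clap_config"]

    tokenizer = prepare_tokenizer(model_config)
    model = prepare_model(
        model_config, clap_config, checkpoint_path="checkpoint.pt", device=0
    )

    # assumed AudioTextDataProcessor signature and item schema
    data_processor = AudioTextDataProcessor(
        data_root=".",
        clap_config=clap_config,
        tokenizer=tokenizer,
        max_tokens=512,
    )
    item = {"name": "example.wav", "prefix": "The task is dialog.", "prompt": "Describe the audio."}
    processed_item = data_processor.process(item)

    # standard HuggingFace sampling arguments forwarded to model.generate()
    inference_kwargs = {"do_sample": True, "top_k": 30, "top_p": 0.95, "num_return_sequences": 1}
    outputs = inference(model, tokenizer, item, processed_item, inference_kwargs, device=0)
    print(outputs)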