|
"""Export ChatTTS components (GPT, DVAE decoder, Vocos vocoder) to ONNX / TorchScript."""

import os, sys
|
|
|
# On macOS, allow ops missing on the MPS backend to fall back to the CPU.
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
|
|
now_dir = os.getcwd() |
|
sys.path.append(now_dir) |
|
|
|
from dataclasses import asdict |
|
import argparse |
|
import torch |
|
from tqdm import tqdm |
|
from ChatTTS.model.dvae import DVAE |
|
from ChatTTS.config import Config |
|
from vocos import Vocos |
|
from vocos.pretrained import instantiate_class |
|
import torch.jit as jit |
|
|
|
from gpt import GPT |
|
|
|
|
|
# Force the export to run on the CPU by pretending CUDA is unavailable.
torch.cuda.is_available = lambda: False
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--gpt", action="store_true", help="trace gpt") |
|
parser.add_argument("--decoder", action="store_true", help="trace decoder") |
|
parser.add_argument("--vocos", action="store_true", help="trace vocos") |
|
parser.add_argument( |
|
"--pth_dir", default="./assets", type=str, help="path to the pth model directory" |
|
) |
|
parser.add_argument( |
|
"--out_dir", default="./tmp", type=str, help="path to output directory" |
|
) |
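
# Example invocation (a sketch; substitute this file's actual name for "exporter.py"):
#   python exporter.py --gpt --decoder --vocos --pth_dir ./assets --out_dir ./tmp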
|
|
|
args = parser.parse_args() |
|
chattts_config = Config() |
|
|
|
|
|
def export_gpt(): |
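    """Export the GPT stack as a set of small ONNX graphs under <out_dir>/gpt.

    Each decoder layer is exported twice: a prefill graph (block_<i>.onnx) that
    consumes a full SEQ_LENGTH window, and a decode graph (block_cache_<i>.onnx)
    that consumes one new token plus cached keys/values. The embeddings, LM
    heads and greedy sampling heads are exported as separate graphs.
    """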
|
gpt_model = GPT(gpt_config=asdict(chattts_config.gpt), use_flash_attn=False).eval() |
|
gpt_model.from_pretrained(asdict(chattts_config.path)["gpt_ckpt_path"]) |
|
gpt_model = gpt_model.eval() |
|
for param in gpt_model.parameters(): |
|
param.requires_grad = False |
|
|
|
config = gpt_model.gpt.config |
|
layers = gpt_model.gpt.layers |
|
model_norm = gpt_model.gpt.norm |
|
|
|
    # Shapes taken from the loaded GPT config; SEQ_LENGTH is the fixed context
    # length baked into every exported graph.
    NUM_OF_LAYERS = config.num_hidden_layers
    HIDDEN_SIZE = config.hidden_size
    NUM_ATTENTION_HEADS = config.num_attention_heads
    NUM_KEY_VALUE_HEADS = config.num_key_value_heads
    HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS
    TEXT_VOCAB_SIZE = gpt_model.emb_text.weight.shape[0]
    AUDIO_VOCAB_SIZE = gpt_model.emb_code[0].weight.shape[0]
    SEQ_LENGTH = 512
|
|
|
folder = os.path.join(args.out_dir, "gpt") |
|
os.makedirs(folder, exist_ok=True) |
|
|
|
for param in gpt_model.emb_text.parameters(): |
|
param.requires_grad = False |
|
|
|
for param in gpt_model.emb_code.parameters(): |
|
param.requires_grad = False |
|
|
|
for param in gpt_model.head_code.parameters(): |
|
param.requires_grad = False |
|
|
|
for param in gpt_model.head_text.parameters(): |
|
param.requires_grad = False |
|
|
|
class EmbeddingText(torch.nn.Module): |
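        """Text-token embedding lookup (wraps gpt_model.emb_text)."""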
|
def __init__(self, *args, **kwargs) -> None: |
|
super().__init__(*args, **kwargs) |
|
|
|
def forward(self, input_ids): |
|
return gpt_model.emb_text(input_ids) |
|
|
|
def convert_embedding_text(): |
|
model = EmbeddingText() |
|
input_ids = torch.tensor([range(SEQ_LENGTH)]) |
|
|
|
torch.onnx.export( |
|
model, |
|
(input_ids), |
|
f"{folder}/embedding_text.onnx", |
|
verbose=False, |
|
input_names=["input_ids"], |
|
output_names=["input_embed"], |
|
do_constant_folding=True, |
|
opset_version=15, |
|
) |
|
|
|
class EmbeddingCode(torch.nn.Module): |
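        """Audio-code embedding for prefill: the same id is looked up in every one of the num_vq codebooks and the embeddings are summed."""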
|
def __init__(self, *args, **kwargs) -> None: |
|
super().__init__(*args, **kwargs) |
|
|
|
def forward(self, input_ids): |
|
input_ids = input_ids.unsqueeze(2).expand( |
|
-1, -1, gpt_model.num_vq |
|
) |
|
code_emb = [ |
|
gpt_model.emb_code[i](input_ids[:, :, i]) |
|
for i in range(gpt_model.num_vq) |
|
] |
|
return torch.stack(code_emb, 2).sum(2) |
|
|
|
def convert_embedding_code(): |
|
model = EmbeddingCode() |
|
input_ids = torch.tensor([range(SEQ_LENGTH)]) |
|
|
|
torch.onnx.export( |
|
model, |
|
(input_ids), |
|
f"{folder}/embedding_code.onnx", |
|
verbose=False, |
|
input_names=["input_ids"], |
|
output_names=["input_embed"], |
|
do_constant_folding=True, |
|
opset_version=15, |
|
) |
|
|
|
class EmbeddingCodeCache(torch.nn.Module): |
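        """Audio-code embedding for the cached decode step: input_ids already carries one id per codebook on its last axis."""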
|
def __init__(self, *args, **kwargs) -> None: |
|
super().__init__(*args, **kwargs) |
|
|
|
def forward(self, input_ids): |
|
code_emb = [ |
|
gpt_model.emb_code[i](input_ids[:, :, i]) |
|
for i in range(gpt_model.num_vq) |
|
] |
|
return torch.stack(code_emb, 2).sum(2) |
|
|
|
def convert_embedding_code_cache(): |
|
model = EmbeddingCodeCache() |
|
input_ids = torch.tensor( |
|
[[[416, 290, 166, 212]]] |
|
) |
|
torch.onnx.export( |
|
model, |
|
(input_ids), |
|
f"{folder}/embedding_code_cache.onnx", |
|
verbose=False, |
|
input_names=["input_ids"], |
|
output_names=["input_embed"], |
|
do_constant_folding=True, |
|
opset_version=15, |
|
) |
|
|
|
class Block(torch.nn.Module): |
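        """Prefill pass of one decoder layer over a full SEQ_LENGTH window; the last layer also applies the final norm."""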
|
def __init__(self, layer_id): |
|
super().__init__() |
|
self.layer_id = layer_id |
|
self.layer = layers[layer_id] |
|
self.norm = model_norm |
|
|
|
def forward(self, hidden_states, position_ids, attention_mask): |
|
hidden_states, past_kv = self.layer( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
use_cache=True, |
|
) |
|
present_k, present_v = past_kv |
|
if self.layer_id == NUM_OF_LAYERS - 1: |
|
hidden_states = self.norm(hidden_states) |
|
return hidden_states, present_k, present_v |
|
|
|
def convert_block(layer_id): |
|
model = Block(layer_id) |
|
hidden_states = torch.randn((1, SEQ_LENGTH, HIDDEN_SIZE)) |
|
position_ids = torch.tensor([range(SEQ_LENGTH)], dtype=torch.long) |
|
        # Causal mask: 0 on and below the diagonal, -1000 above it.
        attention_mask = -1000 * torch.ones(
            (1, 1, SEQ_LENGTH, SEQ_LENGTH), dtype=torch.float32
        ).triu(diagonal=1)
        # Run once to sanity-check the forward pass before exporting.
        model(hidden_states, position_ids, attention_mask)
|
torch.onnx.export( |
|
model, |
|
(hidden_states, position_ids, attention_mask), |
|
f"{folder}/block_{layer_id}.onnx", |
|
verbose=False, |
|
input_names=["input_states", "position_ids", "attention_mask"], |
|
output_names=["hidden_states", "past_k", "past_v"], |
|
do_constant_folding=True, |
|
opset_version=15, |
|
) |
|
|
|
class BlockCache(torch.nn.Module): |
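        """Single-token decode pass of one decoder layer using an external KV cache; the last layer also applies the final norm."""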
|
|
|
def __init__(self, layer_id): |
|
super().__init__() |
|
self.layer_id = layer_id |
|
self.layer = layers[layer_id] |
|
self.norm = model_norm |
|
|
|
def forward(self, hidden_states, position_ids, attention_mask, past_k, past_v): |
|
hidden_states, past_kv = self.layer( |
|
hidden_states, |
|
attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=(past_k, past_v), |
|
use_cache=True, |
|
) |
|
present_k, present_v = past_kv |
|
if self.layer_id == NUM_OF_LAYERS - 1: |
|
hidden_states = self.norm(hidden_states) |
|
return hidden_states, present_k, present_v |
|
|
|
def convert_block_cache(layer_id): |
|
model = BlockCache(layer_id) |
|
hidden_states = torch.randn((1, 1, HIDDEN_SIZE)) |
|
position_ids = torch.tensor([range(1)], dtype=torch.long) |
|
        # Dummy mask covering SEQ_LENGTH cached positions plus the new token.
        attention_mask = -1000 * torch.ones(
            (1, 1, 1, SEQ_LENGTH + 1), dtype=torch.float32
        ).triu(diagonal=1)
        # Dummy KV-cache tensors spanning the full SEQ_LENGTH history.
        past_k = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM))
        past_v = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM))
|
|
|
torch.onnx.export( |
|
model, |
|
(hidden_states, position_ids, attention_mask, past_k, past_v), |
|
f"{folder}/block_cache_{layer_id}.onnx", |
|
verbose=False, |
|
input_names=[ |
|
"input_states", |
|
"position_ids", |
|
"attention_mask", |
|
"history_k", |
|
"history_v", |
|
], |
|
output_names=["hidden_states", "past_k", "past_v"], |
|
do_constant_folding=True, |
|
opset_version=15, |
|
) |
|
|
|
class GreedyHead(torch.nn.Module): |
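        """Greedy sampling head: return the top-1 (argmax) index of the logits along the last dimension."""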
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
def forward(self, m_logits): |
|
_, token = torch.topk(m_logits.float(), 1) |
|
return token |
|
|
|
def convert_greedy_head_text(): |
|
model = GreedyHead() |
|
m_logits = torch.randn(1, TEXT_VOCAB_SIZE) |
|
|
|
torch.onnx.export( |
|
model, |
|
(m_logits), |
|
f"{folder}/greedy_head_text.onnx", |
|
verbose=False, |
|
input_names=["m_logits"], |
|
output_names=["token"], |
|
do_constant_folding=True, |
|
opset_version=15, |
|
) |
|
|
|
def convert_greedy_head_code(): |
|
model = GreedyHead() |
|
m_logits = torch.randn(1, AUDIO_VOCAB_SIZE, gpt_model.num_vq) |
|
|
|
torch.onnx.export( |
|
model, |
|
(m_logits), |
|
f"{folder}/greedy_head_code.onnx", |
|
verbose=False, |
|
input_names=["m_logits"], |
|
output_names=["token"], |
|
do_constant_folding=True, |
|
opset_version=15, |
|
) |
|
|
|
class LmHead_infer_text(torch.nn.Module): |
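        """Project hidden states to text-vocabulary logits with the text LM head."""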
|
def __init__(self): |
|
super().__init__() |
|
|
|
def forward(self, hidden_states): |
|
m_logits = gpt_model.head_text(hidden_states) |
|
return m_logits |
|
|
|
class LmHead_infer_code(torch.nn.Module): |
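        """Project hidden states through each of the num_vq audio-code heads and stack the logits."""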
|
def __init__(self): |
|
super().__init__() |
|
|
|
def forward(self, hidden_states): |
|
m_logits = torch.stack( |
|
[ |
|
gpt_model.head_code[i](hidden_states) |
|
for i in range(gpt_model.num_vq) |
|
], |
|
2, |
|
) |
|
return m_logits |
|
|
|
def convert_lm_head_text(): |
|
model = LmHead_infer_text() |
|
input = torch.randn(1, HIDDEN_SIZE) |
|
|
|
torch.onnx.export( |
|
model, |
|
(input), |
|
f"{folder}/lm_head_text.onnx", |
|
verbose=False, |
|
input_names=["hidden_states"], |
|
output_names=["m_logits"], |
|
do_constant_folding=True, |
|
opset_version=15, |
|
) |
|
|
|
def convert_lm_head_code(): |
|
model = LmHead_infer_code() |
|
input = torch.randn(1, HIDDEN_SIZE) |
|
torch.onnx.export( |
|
model, |
|
(input), |
|
f"{folder}/lm_head_code.onnx", |
|
verbose=False, |
|
input_names=["hidden_states"], |
|
output_names=["m_logits"], |
|
do_constant_folding=True, |
|
opset_version=15, |
|
) |
|
|
|
|
|
    print("Convert block & block_cache")
|
for i in tqdm(range(NUM_OF_LAYERS)): |
|
convert_block(i) |
|
convert_block_cache(i) |
|
|
|
    print("Convert embedding")
|
convert_embedding_text() |
|
convert_embedding_code() |
|
convert_embedding_code_cache() |
|
|
|
    print("Convert lm_head")
|
convert_lm_head_code() |
|
convert_lm_head_text() |
|
|
|
    print("Convert greedy_head")
|
convert_greedy_head_text() |
|
convert_greedy_head_code() |
|
|
|
|
|
def export_decoder(): |
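    """Trace the DVAE decoder (mode="decode") with TorchScript and save it as decoder_jit.pt."""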
|
decoder = DVAE( |
|
decoder_config=asdict(chattts_config.decoder), |
|
dim=chattts_config.decoder.idim, |
|
).eval() |
|
decoder.load_state_dict( |
|
torch.load( |
|
asdict(chattts_config.path)["decoder_ckpt_path"], |
|
weights_only=True, |
|
mmap=True, |
|
) |
|
) |
|
|
|
for param in decoder.parameters(): |
|
param.requires_grad = False |
|
rand_input = torch.rand([1, 768, 1024], requires_grad=False) |
|
|
|
def mydec(_inp): |
|
return decoder(_inp, mode="decode") |
|
|
|
jitmodel = jit.trace(mydec, [rand_input]) |
|
jit.save(jitmodel, f"{args.out_dir}/decoder_jit.pt") |
|
|
|
|
|
def export_vocos(): |
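    """Export the Vocos backbone and its opened-up ISTFT head to ONNX.

    The traced function stops before the inverse STFT and returns magnitude,
    cos(phase) and sin(phase) instead of a waveform.
    """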
|
feature_extractor = instantiate_class( |
|
args=(), init=asdict(chattts_config.vocos.feature_extractor) |
|
) |
|
backbone = instantiate_class(args=(), init=asdict(chattts_config.vocos.backbone)) |
|
head = instantiate_class(args=(), init=asdict(chattts_config.vocos.head)) |
|
vocos = Vocos( |
|
feature_extractor=feature_extractor, backbone=backbone, head=head |
|
).eval() |
|
vocos.load_state_dict( |
|
torch.load( |
|
asdict(chattts_config.path)["vocos_ckpt_path"], weights_only=True, mmap=True |
|
) |
|
) |
|
|
|
for param in vocos.parameters(): |
|
param.requires_grad = False |
|
rand_input = torch.rand([1, 100, 2048], requires_grad=False) |
|
|
|
def myvocos(_inp): |
|
|
|
|
|
x = vocos.backbone(_inp) |
|
x = vocos.head.out(x).transpose(1, 2) |
|
mag, p = x.chunk(2, dim=1) |
|
mag = torch.exp(mag) |
|
        mag = torch.clip(mag, max=1e2)  # safeguard against excessively large magnitudes
|
|
|
x = torch.cos(p) |
|
y = torch.sin(p) |
|
return mag, x, y |
|
|
|
jitmodel = jit.trace(myvocos, [rand_input]) |
|
torch.onnx.export( |
|
jitmodel, |
|
[rand_input], |
|
f"{args.out_dir}/vocos_1-100-2048.onnx", |
|
opset_version=12, |
|
do_constant_folding=True, |
|
) |
|
|
|
|
|
if args.gpt: |
|
export_gpt() |
|
|
|
if args.decoder: |
|
export_decoder() |
|
|
|
if args.vocos: |
|
export_vocos() |
|
|
|
print("Done. Please check the files in", args.out_dir) |
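
# Quick smoke test of one exported graph (a sketch; assumes `onnxruntime` and numpy
# are installed and that --gpt was run with the default --out_dir):
#   import numpy as np, onnxruntime as ort
#   sess = ort.InferenceSession("./tmp/gpt/embedding_text.onnx")
#   (embed,) = sess.run(["input_embed"], {"input_ids": np.arange(512, dtype=np.int64)[None]})
#   print(embed.shape)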
|
|