# ChatTTS2/examples/onnx/exporter.py
import os
import sys

if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

now_dir = os.getcwd()
sys.path.append(now_dir)
from dataclasses import asdict
import argparse
import torch
from tqdm import tqdm
from ChatTTS.model.dvae import DVAE
from ChatTTS.config import Config
from vocos import Vocos
from vocos.pretrained import instantiate_class
import torch.jit as jit
from gpt import GPT
# disable cuda
torch.cuda.is_available = lambda: False
# add args to control which modules to export
parser = argparse.ArgumentParser()
parser.add_argument("--gpt", action="store_true", help="trace gpt")
parser.add_argument("--decoder", action="store_true", help="trace decoder")
parser.add_argument("--vocos", action="store_true", help="trace vocos")
parser.add_argument(
    "--pth_dir", default="./assets", type=str, help="path to the pth model directory"
)
parser.add_argument(
    "--out_dir", default="./tmp", type=str, help="path to output directory"
)
args = parser.parse_args()
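
# Example invocation (note: --pth_dir is parsed but never read below; the
# checkpoint paths actually come from ChatTTS.config.Config):
#   python examples/onnx/exporter.py --gpt --decoder --vocos --out_dir ./tmp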
chattts_config = Config()

def export_gpt():
    gpt_model = GPT(gpt_config=asdict(chattts_config.gpt), use_flash_attn=False).eval()
    gpt_model.from_pretrained(asdict(chattts_config.path)["gpt_ckpt_path"])
    for param in gpt_model.parameters():
        param.requires_grad = False

    config = gpt_model.gpt.config
    layers = gpt_model.gpt.layers
    model_norm = gpt_model.gpt.norm

    NUM_OF_LAYERS = config.num_hidden_layers
    HIDDEN_SIZE = config.hidden_size
    NUM_ATTENTION_HEADS = config.num_attention_heads
    NUM_KEY_VALUE_HEADS = config.num_key_value_heads
    HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS  # 64
    TEXT_VOCAB_SIZE = gpt_model.emb_text.weight.shape[0]
    AUDIO_VOCAB_SIZE = gpt_model.emb_code[0].weight.shape[0]
    SEQ_LENGTH = 512

    folder = os.path.join(args.out_dir, "gpt")
    os.makedirs(folder, exist_ok=True)
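
    # The GPT is exported as many small fixed-shape graphs (embeddings, one
    # graph per decoder layer, heads) rather than a single ONNX file,
    # presumably so a downstream compiler (e.g. a TPU toolchain) can handle
    # each piece separately.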
    class EmbeddingText(torch.nn.Module):
        def forward(self, input_ids):
            return gpt_model.emb_text(input_ids)

    def convert_embedding_text():
        model = EmbeddingText()
        input_ids = torch.tensor([range(SEQ_LENGTH)])
        torch.onnx.export(
            model,
            (input_ids,),
            f"{folder}/embedding_text.onnx",
            verbose=False,
            input_names=["input_ids"],
            output_names=["input_embed"],
            do_constant_folding=True,
            opset_version=15,
        )

    class EmbeddingCode(torch.nn.Module):
        def forward(self, input_ids):
            # for forward_first_code: one token id per position, expanded to
            # all num_vq codebooks
            input_ids = input_ids.unsqueeze(2).expand(-1, -1, gpt_model.num_vq)
            code_emb = [
                gpt_model.emb_code[i](input_ids[:, :, i])
                for i in range(gpt_model.num_vq)
            ]
            return torch.stack(code_emb, 2).sum(2)

    def convert_embedding_code():
        model = EmbeddingCode()
        input_ids = torch.tensor([range(SEQ_LENGTH)])
        torch.onnx.export(
            model,
            (input_ids,),
            f"{folder}/embedding_code.onnx",
            verbose=False,
            input_names=["input_ids"],
            output_names=["input_embed"],
            do_constant_folding=True,
            opset_version=15,
        )

    class EmbeddingCodeCache(torch.nn.Module):  # for forward_next_code
        def forward(self, input_ids):
            # num_vq token ids per position, one per VQ codebook
            code_emb = [
                gpt_model.emb_code[i](input_ids[:, :, i])
                for i in range(gpt_model.num_vq)
            ]
            return torch.stack(code_emb, 2).sum(2)

    def convert_embedding_code_cache():
        model = EmbeddingCodeCache()
        input_ids = torch.tensor(
            [[[416, 290, 166, 212]]]
        )  # arbitrary sample ids, shape (1, 1, num_vq)
        torch.onnx.export(
            model,
            (input_ids,),
            f"{folder}/embedding_code_cache.onnx",
            verbose=False,
            input_names=["input_ids"],
            output_names=["input_embed"],
            do_constant_folding=True,
            opset_version=15,
        )
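
    # The two code-embedding variants mirror forward_first_code /
    # forward_next_code: the first forward expands one shared id per position
    # across all num_vq codebooks, while subsequent steps embed num_vq
    # distinct ids per position.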

    class Block(torch.nn.Module):
        def __init__(self, layer_id):
            super().__init__()
            self.layer_id = layer_id
            self.layer = layers[layer_id]  # LlamaDecoderLayer
            self.norm = model_norm

        def forward(self, hidden_states, position_ids, attention_mask):
            hidden_states, past_kv = self.layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=True,
            )
            present_k, present_v = past_kv
            if self.layer_id == NUM_OF_LAYERS - 1:
                hidden_states = self.norm(hidden_states)
            return hidden_states, present_k, present_v

    def convert_block(layer_id):
        model = Block(layer_id)
        hidden_states = torch.randn((1, SEQ_LENGTH, HIDDEN_SIZE))
        position_ids = torch.tensor([range(SEQ_LENGTH)], dtype=torch.long)
        # additive causal mask: 0 on and below the diagonal, -1000 above it
        attention_mask = -1000 * torch.ones(
            (1, 1, SEQ_LENGTH, SEQ_LENGTH), dtype=torch.float32
        ).triu(diagonal=1)
        model(hidden_states, position_ids, attention_mask)  # dry run before export
        torch.onnx.export(
            model,
            (hidden_states, position_ids, attention_mask),
            f"{folder}/block_{layer_id}.onnx",
            verbose=False,
            input_names=["input_states", "position_ids", "attention_mask"],
            output_names=["hidden_states", "past_k", "past_v"],
            do_constant_folding=True,
            opset_version=15,
        )

    class BlockCache(torch.nn.Module):
        def __init__(self, layer_id):
            super().__init__()
            self.layer_id = layer_id
            self.layer = layers[layer_id]
            self.norm = model_norm

        def forward(self, hidden_states, position_ids, attention_mask, past_k, past_v):
            hidden_states, past_kv = self.layer(
                hidden_states,
                attention_mask,
                position_ids=position_ids,
                past_key_value=(past_k, past_v),
                use_cache=True,
            )
            present_k, present_v = past_kv
            if self.layer_id == NUM_OF_LAYERS - 1:
                hidden_states = self.norm(hidden_states)
            return hidden_states, present_k, present_v

    def convert_block_cache(layer_id):
        model = BlockCache(layer_id)
        # placeholder inputs: only the shapes matter for tracing
        hidden_states = torch.randn((1, 1, HIDDEN_SIZE))
        position_ids = torch.tensor([range(1)], dtype=torch.long)
        attention_mask = -1000 * torch.ones(
            (1, 1, 1, SEQ_LENGTH + 1), dtype=torch.float32
        ).triu(diagonal=1)
        past_k = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM))
        past_v = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM))
        torch.onnx.export(
            model,
            (hidden_states, position_ids, attention_mask, past_k, past_v),
            f"{folder}/block_cache_{layer_id}.onnx",
            verbose=False,
            input_names=[
                "input_states",
                "position_ids",
                "attention_mask",
                "history_k",
                "history_v",
            ],
            output_names=["hidden_states", "past_k", "past_v"],
            do_constant_folding=True,
            opset_version=15,
        )
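
    # Block is the prefill pass over a full SEQ_LENGTH window under a causal
    # mask; BlockCache is the single-token decode step, consuming the
    # per-layer KV cache as history_k/history_v and emitting the updated
    # past_k/past_v.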

    class GreedyHead(torch.nn.Module):
        def forward(self, m_logits):
            _, token = torch.topk(m_logits.float(), 1)
            return token

    def convert_greedy_head_text():
        model = GreedyHead()
        m_logits = torch.randn(1, TEXT_VOCAB_SIZE)
        torch.onnx.export(
            model,
            (m_logits,),
            f"{folder}/greedy_head_text.onnx",
            verbose=False,
            input_names=["m_logits"],
            output_names=["token"],
            do_constant_folding=True,
            opset_version=15,
        )

    def convert_greedy_head_code():
        model = GreedyHead()
        m_logits = torch.randn(1, AUDIO_VOCAB_SIZE, gpt_model.num_vq)
        torch.onnx.export(
            model,
            (m_logits,),
            f"{folder}/greedy_head_code.onnx",
            verbose=False,
            input_names=["m_logits"],
            output_names=["token"],
            do_constant_folding=True,
            opset_version=15,
        )
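
    # topk(logits, 1) is just argmax, i.e. greedy decoding; a sampling head
    # (temperature / top-p) could replace this module if stochastic decoding
    # is wanted.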

    class LmHead_infer_text(torch.nn.Module):
        def forward(self, hidden_states):
            return gpt_model.head_text(hidden_states)

    class LmHead_infer_code(torch.nn.Module):
        def forward(self, hidden_states):
            return torch.stack(
                [
                    gpt_model.head_code[i](hidden_states)
                    for i in range(gpt_model.num_vq)
                ],
                2,
            )

    def convert_lm_head_text():
        model = LmHead_infer_text()
        hidden_states = torch.randn(1, HIDDEN_SIZE)
        torch.onnx.export(
            model,
            (hidden_states,),
            f"{folder}/lm_head_text.onnx",
            verbose=False,
            input_names=["hidden_states"],
            output_names=["m_logits"],
            do_constant_folding=True,
            opset_version=15,
        )

    def convert_lm_head_code():
        model = LmHead_infer_code()
        hidden_states = torch.randn(1, HIDDEN_SIZE)
        torch.onnx.export(
            model,
            (hidden_states,),
            f"{folder}/lm_head_code.onnx",
            verbose=False,
            input_names=["hidden_states"],
            output_names=["m_logits"],
            do_constant_folding=True,
            opset_version=15,
        )

    # export models
    print("Convert block & block_cache")
    for i in tqdm(range(NUM_OF_LAYERS)):
        convert_block(i)
        convert_block_cache(i)
    print("Convert embedding")
    convert_embedding_text()
    convert_embedding_code()
    convert_embedding_code_cache()
    print("Convert lm_head")
    convert_lm_head_code()
    convert_lm_head_text()
    print("Convert greedy_head")
    convert_greedy_head_text()
    convert_greedy_head_code()
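
# Assumed runtime wiring of the exported GPT graphs (not implemented here):
# embed text/code tokens, prefill once through block_0..block_{N-1}, then loop
# single-token decode through block_cache_*, feeding each layer's returned
# past_k/past_v back in as history_k/history_v, with lm_head_* producing
# logits and greedy_head_* picking the next token.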

def export_decoder():
    decoder = DVAE(
        decoder_config=asdict(chattts_config.decoder),
        dim=chattts_config.decoder.idim,
    ).eval()
    decoder.load_state_dict(
        torch.load(
            asdict(chattts_config.path)["decoder_ckpt_path"],
            weights_only=True,
            mmap=True,
        )
    )
    for param in decoder.parameters():
        param.requires_grad = False

    rand_input = torch.rand([1, 768, 1024], requires_grad=False)

    def mydec(_inp):
        return decoder(_inp, mode="decode")

    jitmodel = jit.trace(mydec, [rand_input])
    jit.save(jitmodel, f"{args.out_dir}/decoder_jit.pt")
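
# Unlike the GPT pieces, the DVAE decoder is kept as a traced TorchScript
# module (decoder_jit.pt) rather than being converted to ONNX.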

def export_vocos():
    feature_extractor = instantiate_class(
        args=(), init=asdict(chattts_config.vocos.feature_extractor)
    )
    backbone = instantiate_class(args=(), init=asdict(chattts_config.vocos.backbone))
    head = instantiate_class(args=(), init=asdict(chattts_config.vocos.head))
    vocos = Vocos(
        feature_extractor=feature_extractor, backbone=backbone, head=head
    ).eval()
    vocos.load_state_dict(
        torch.load(
            asdict(chattts_config.path)["vocos_ckpt_path"], weights_only=True, mmap=True
        )
    )
    for param in vocos.parameters():
        param.requires_grad = False

    rand_input = torch.rand([1, 100, 2048], requires_grad=False)

    def myvocos(_inp):
        # Cannot simply `return vocos.decode(_inp)`: the TPU does not support
        # the istft op, so the iSTFT is moved to postprocessing.
        # Reference: https://github.com/gemelo-ai/vocos.git
        x = vocos.backbone(_inp)
        x = vocos.head.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(
            mag, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        # Phase wrapping happens here: these two lines produce the real and
        # imaginary parts of the unit phasor.
        x = torch.cos(p)
        y = torch.sin(p)
        return mag, x, y

    jitmodel = jit.trace(myvocos, [rand_input])
    torch.onnx.export(
        jitmodel,
        (rand_input,),
        f"{args.out_dir}/vocos_1-100-2048.onnx",
        opset_version=12,
        do_constant_folding=True,
    )
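
# Host-side postprocessing sketch for the vocos export (a hypothetical helper,
# not part of this exporter): the ONNX graph returns (mag, cos(p), sin(p)),
# and audio is recovered with an inverse STFT, mirroring vocos' ISTFTHead.
# The n_fft and hop_length defaults below are placeholders and must match the
# vocos head config.
def istft_postprocess(mag, x, y, n_fft=1024, hop_length=256):
    spec = torch.complex(mag * x, mag * y)  # complex spectrogram, (B, n_fft//2+1, T)
    window = torch.hann_window(n_fft)
    return torch.istft(spec, n_fft, hop_length=hop_length, window=window, center=True)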

if args.gpt:
    export_gpt()
if args.decoder:
    export_decoder()
if args.vocos:
    export_vocos()
print("Done. Please check the files in", args.out_dir)