Spaces:
Sleeping
Sleeping
# Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# Redistribution and use in source and binary forms, with or without modification, are permitted | |
# provided that the following conditions are met: | |
# | |
# 1. Redistributions of source code must retain the above copyright notice, this list of | |
# conditions and the following disclaimer. | |
# | |
# 2. Redistributions in binary form must reproduce the above copyright notice, this | |
# list of conditions and the following disclaimer in the documentation and/or other | |
# materials provided with the distribution. | |
# | |
# 3. Neither the name of the copyright holder nor the names of its contributors | |
# may be used to endorse or promote products derived from this software without | |
# specific prior written permission. | |
# | |
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR | |
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND | |
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR | |
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
import itertools | |
import time | |
from pathlib import Path | |
from typing import Optional, Tuple | |
import torch | |
import torch._dynamo.config | |
import torch._inductor.config | |
import tqdm | |
def device_sync(device):
    """Block until work queued on ``device`` has finished.

    Called around the timing measurements in this module so that wall-clock
    numbers include asynchronously launched accelerator work.

    Args:
        device: A device string such as ``"cuda"``, ``"cuda:0"``, ``"mps"`` or
            ``"cpu"``, or a ``torch.device`` (normalised with ``str``).

    Unknown device types are not an error: a message is printed and the call
    returns without synchronizing.
    """
    # Accept torch.device as well as str; `in` on a torch.device raises TypeError.
    device = str(device)
    if "cuda" in device:
        torch.cuda.synchronize()
    elif "mps" in device:
        torch.mps.synchronize()
    elif "cpu" in device:
        # Nothing to wait for on CPU.
        pass
    else:
        print(f"device={device} is not yet supported")
torch._inductor.config.coordinate_descent_tuning = True | |
torch._inductor.config.triton.unique_kernel_names = True | |
# torch._inductor.config.fx_graph_cache = ( | |
# True # Experimental feature to reduce compilation times, will be on by default in future | |
# ) | |
# imports need to happen after setting above flags | |
from fam.llm.fast_model import Transformer | |
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder | |
from fam.quantiser.text.tokenise import TrainedBPETokeniser | |
def multinomial_sample_one_no_sync(
    probs_sort,
):
    """Sample one index per row of ``probs_sort`` without a CUDA synchronization.

    Uses the exponential-race trick. Each probability is divided by an
    independent Exp(1) draw, and the argmax over the last dimension is taken.
    Index ``i`` wins with probability proportional to ``probs_sort[..., i]``.

    Returns:
        The sampled indices, with the last dimension kept as size 1, cast to
        ``torch.int``.
    """
    exp_noise = torch.empty_like(probs_sort).exponential_(1)
    race = probs_sort / exp_noise
    winner = race.argmax(dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)
def top_p_sample(logits: torch.Tensor, top_p: torch.Tensor):
    """Mask logits that fall outside the nucleus (top-p) set with ``-inf``.

    Tokens are sorted by ascending logit. Those whose cumulative probability
    (summed from the least likely upward) is ``<= 1 - top_p`` are removed. The
    single most likely token is always kept.

    The vocabulary is taken to be the last dimension, so both a single
    ``(vocab,)`` vector and a batched ``(..., vocab)`` tensor are supported.

    Args:
        logits: Unnormalized scores, vocabulary on the last dimension.
        top_p: Probability mass to keep, e.g. a scalar tensor; it is used as
            ``1 - top_p`` in a comparison against the cumulative probabilities.

    Returns:
        A tensor shaped like ``logits`` with the filtered entries set to ``-inf``.
    """
    # ref: huggingface/transformers
    sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Always keep the most likely token (last position after an ascending sort).
    # `[..., -1:]` addresses the vocab axis for any rank; the previous `[-1:]`
    # hit the batch axis for 2-D input.
    sorted_indices_to_remove[..., -1:] = False
    # Scatter the mask from sorted order back to vocabulary order along the
    # vocab axis (dim 0 was only correct for 1-D input).
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, -float("Inf"))
def logits_to_probs(
    logits,
    *,
    temperature: torch.Tensor,
    top_p: Optional[torch.Tensor] = None,
    top_k: Optional[torch.Tensor] = None,
):
    """Convert raw logits into a sampling distribution over the last dimension.

    The steps are applied in this order:

    1. Divide by ``temperature``, floored at ``1e-5`` so a zero temperature
       does not divide by zero.
    2. If ``top_k`` is given, set every logit below the k-th largest to ``-inf``.
    3. If ``top_p`` is given, apply nucleus filtering via ``top_p_sample``.
    4. Apply softmax over the last dimension.

    Returns:
        Probabilities with the same shape as ``logits``.
    """
    temperature_floor = 1e-5 * torch.ones_like(temperature)
    scaled = logits / torch.max(temperature, temperature_floor)

    if top_k is not None:
        k = min(top_k, scaled.size(-1))
        kth_largest = torch.topk(scaled, k).values.select(-1, -1).unsqueeze(-1)
        scaled = scaled.masked_fill(scaled < kth_largest, -float("Inf"))

    if top_p is not None:
        scaled = top_p_sample(scaled, top_p)

    return torch.nn.functional.softmax(scaled, dim=-1)
def sample(
    logits,
    guidance_scale: torch.Tensor,
    temperature: torch.Tensor,
    top_p: Optional[torch.Tensor] = None,
    top_k: Optional[torch.Tensor] = None,
):
    """Sample the next token, mixing two halves of the batch for guidance.

    ``logits`` is annotated upstream as ``(b, t, vocab_size)``; only the last
    time step is used. The batch is split into two equal halves:

    - the first half is the conditional branch;
    - the second half is, per the variable naming, the branch without speaker
      embedding.

    The halves are combined as ``g * cond + (1 - g) * uncond``. Only the first
    row of the mix is turned into probabilities and sampled.

    Returns:
        ``(idx_next, probs)``: the sampled token index and the probability
        vector it was drawn from.
    """
    last_step = logits[:, -1]
    half = last_step.size(0) // 2
    cond, uncond_spkemb = last_step.split(half, dim=0)
    guided = guidance_scale * cond + (1 - guidance_scale) * uncond_spkemb
    probs = logits_to_probs(guided[0], temperature=temperature, top_p=top_p, top_k=top_k)
    return multinomial_sample_one_no_sync(probs), probs
def prefill(
    model: Transformer,
    x: torch.Tensor,
    spk_emb: torch.Tensor,
    input_pos: torch.Tensor,
    **sampling_kwargs,
) -> torch.Tensor:
    """Run the model over the full prompt and sample the first new token.

    Args:
        model: Called as ``model(x, spk_emb, input_pos)``.
        x: Prompt tokens (``generate`` passes the prompt duplicated into 2 rows).
        spk_emb: Speaker embedding forwarded to the model.
        input_pos: Positions of the prompt tokens. The original note says
            ``[B, S]``, but ``generate`` passes a 1-D ``arange``; TODO confirm
            the expected rank.
        **sampling_kwargs: Forwarded unchanged to ``sample``.

    Returns:
        Only the sampled token index; the probabilities are discarded.

    ``build_model`` may rebind this module-level name to a ``torch.compile``
    version, so keep the body traceable.
    """
    next_token, _unused_probs = sample(model(x, spk_emb, input_pos), **sampling_kwargs)
    return next_token
def decode_one_token(
    model: Transformer,
    x: torch.Tensor,
    spk_emb: torch.Tensor,
    input_pos: torch.Tensor,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run a single incremental decoding step and sample the following token.

    Args:
        model: Called as ``model(x, spk_emb, input_pos)``.
        x: Token(s) for the current step.
        spk_emb: Speaker embedding forwarded to the model.
        input_pos: Position of this step; its last dimension must be 1 (the
            original annotation reads ``[B, 1]``).
        **sampling_kwargs: Forwarded unchanged to ``sample``.

    Returns:
        ``(next_token, probs)`` exactly as produced by ``sample``.

    ``build_model`` may replace this module-level function with a
    ``torch.compile(..., fullgraph=True)`` version, so keep the body traceable.
    """
    assert input_pos.size(-1) == 1
    step_logits = model(x, spk_emb, input_pos)
    return sample(step_logits, **sampling_kwargs)
def decode_n_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    spk_emb: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    callback=lambda _: _,
    return_probs: bool = False,
    end_of_audio_token: int = 2048,
    **sampling_kwargs,
):
    """Autoregressively decode up to ``num_new_tokens`` tokens.

    Before each step the current token is compared with
    ``end_of_audio_token``. Decoding stops as soon as it matches, so an
    end-of-audio token that was generated is kept, but nothing follows it. If
    the incoming ``cur_token`` already matches, nothing is generated.

    Side effects:
        - ``input_pos`` is incremented in place after every step, so the
          caller's tensor advances too.
        - ``callback`` receives each newly generated (cloned) token.

    Args:
        model: Passed to ``decode_one_token``.
        cur_token: Token fed to the first step; each later step feeds the
            previous token duplicated into 2 rows.
        spk_emb: Speaker embedding forwarded to the model.
        input_pos: Position tensor for the current step.
        num_new_tokens: Maximum number of steps.
        callback: Called with each newly generated (cloned) token.
        return_probs: Whether to also collect the per-step probabilities.
        end_of_audio_token: Token id that stops decoding.
        **sampling_kwargs: Forwarded unchanged to ``sample``.

    Returns:
        ``(new_tokens, new_probs)``: lists of cloned tensors. ``new_probs`` is
        empty unless ``return_probs`` is true.

    Notes:
        - ``decode_one_token`` is looked up at call time on purpose.
          ``build_model`` may rebind the module-level name to a compiled
          version.
        - The end-of-audio check turns a tensor into a Python bool every step,
          which waits on the device.
    """
    tokens_out, probs_out = [], []
    for _step in tqdm.tqdm(range(num_new_tokens)):
        if (cur_token == end_of_audio_token).any():
            break
        # Math-only SDPA backend; the original note says this is actually
        # better for Inductor to codegen attention here.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            next_token, next_prob = decode_one_token(model, cur_token, spk_emb, input_pos, **sampling_kwargs)
            input_pos += 1  # in place: the caller's tensor advances as well
            token_copy = next_token.clone()
            tokens_out.append(token_copy)
            callback(token_copy)
            if return_probs:
                probs_out.append(next_prob.clone())
            # Duplicate the token into 2 rows for the guided/unguided batch.
            cur_token = next_token.view(1, -1).repeat(2, 1)
    return tokens_out, probs_out
def model_forward(model, x, spk_emb, input_pos):
    """Call ``model(x, spk_emb, input_pos)`` and hand back its output unchanged."""
    output = model(x, spk_emb, input_pos)
    return output
def generate(
    model: Transformer,
    prompt: torch.Tensor,
    spk_emb: torch.Tensor,
    *,
    max_new_tokens: Optional[int] = None,
    callback=lambda x: x,
    end_of_audio_token: int = 2048,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Args:
        model: The transformer. ``build_model`` calls ``setup_caches`` before
            generating; presumably that is required here too.
        prompt: 1-D tensor of prompt token ids.
        spk_emb: Speaker embedding forwarded to every model call.
        max_new_tokens: Upper bound on newly generated tokens. ``None`` means
            "fill up to ``model.config.block_size``". The total length is
            always capped at ``block_size``.
        callback: Called with every token produced by the incremental decoding
            loop. The first (prefill) token is not passed to it.
        end_of_audio_token: Decoding stops once this token has been produced.
        **sampling_kwargs: Forwarded to ``sample``.

    Returns:
        1-D tensor: the prompt followed by all generated tokens.

    Raises:
        ValueError: If the prompt leaves no room to generate within ``block_size``.
    """
    T = prompt.size(0)
    if max_new_tokens is None:
        max_seq_length = model.config.block_size
    else:
        max_seq_length = T + max_new_tokens
    max_seq_length = min(max_seq_length, model.config.block_size)
    max_new_tokens = max_seq_length - T
    if max_new_tokens <= 0:
        raise ValueError("Prompt is too long to generate more tokens")

    device = prompt.device
    seq = torch.clone(prompt)

    # Prefill: run the whole prompt, duplicated into 2 rows for `sample`'s guidance mix.
    input_pos = torch.arange(0, T, device=device)
    next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs)
    seq = torch.cat([seq, next_token.view(1)])

    # Incremental decoding of the remaining tokens, one position at a time.
    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    generated_tokens, _ = decode_n_tokens(
        model,
        next_token.view(1, -1).repeat(2, 1),
        spk_emb,
        input_pos,
        max_new_tokens - 1,
        callback=callback,
        end_of_audio_token=end_of_audio_token,
        **sampling_kwargs,
    )
    # decode_n_tokens returns [] when max_new_tokens == 1 or when the prefill
    # token is already end_of_audio_token; torch.cat([]) would raise.
    if generated_tokens:
        seq = torch.cat([seq, torch.cat(generated_tokens)])
    return seq
def encode_tokens(tokenizer, string, device="cuda"):
    """Encode ``string`` with ``tokenizer.encode`` into a ``torch.int`` tensor on ``device``."""
    token_ids = tokenizer.encode(string)
    return torch.tensor(
        token_ids,
        device=device,
        dtype=torch.int,
    )
# Substring renames applied, in order, to every checkpoint key to map
# MetaVoice-1B block parameter names onto the gpt-fast Transformer's names.
_METAVOICE_LAYER_KEY_RENAMES = (
    (".attn.c_attn.", ".attention.wqkv."),
    (".attn.c_proj.", ".attention.wo."),
    (".mlp.swiglu.w1.", ".feed_forward.swiglu.w1."),
    (".mlp.swiglu.w3.", ".feed_forward.swiglu.w3."),
    (".ln_1.", ".attention_norm."),
    (".ln_2.", ".ffn_norm."),
    (".mlp.c_proj.", ".feed_forward.w2."),
)


def _rename_metavoice_key(key):
    """Return the gpt-fast name for a single MetaVoice-1B checkpoint key."""
    # The block prefix is rewritten only when the key starts with it.
    if key.startswith("transformer.h."):
        key = key.replace("transformer.h.", "layers.")
    for old, new in _METAVOICE_LAYER_KEY_RENAMES:
        key = key.replace(old, new)
    return key


def _convert_metavoice_state_dict(state_dict):
    """Rename MetaVoice-1B checkpoint keys to gpt-fast naming, in place.

    Returns:
        The same ``state_dict`` object, for convenience.

    Raises:
        KeyError: If one of the expected top-level tensors is missing. These
            are the token/position embeddings, the LM head and the final norm.
    """
    # Strip the "_orig_mod." prefix (presumably left by saving a torch.compile'd module).
    unwanted_prefix = "_orig_mod."
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)

    state_dict["tok_embeddings.weight"] = state_dict.pop("transformer.wtes.0.weight")
    state_dict["pos_embeddings.weight"] = state_dict.pop("transformer.wpe.weight")
    state_dict["output.weight"] = state_dict.pop("lm_heads.0.weight")
    state_dict["norm.weight"] = state_dict.pop("transformer.ln_f.weight")

    # Compute each key's final name once, then move the tensor a single time.
    for old_key in list(state_dict):
        new_key = _rename_metavoice_key(old_key)
        if new_key != old_key:
            state_dict[new_key] = state_dict.pop(old_key)
    return state_dict


def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
    """Load the MetaVoice-1B transformer, its tokenizer and the speaker encoder.

    Args:
        checkpoint_path: LLM checkpoint. It must contain a ``"model"`` state
            dict and may contain ``["meta"]["tokenizer"]`` keyword arguments
            for ``TrainedBPETokeniser``.
        spk_emb_ckpt_path: Weights file for ``SpeakerEncoder``.
        device: Device for the transformer and the speaker encoder.
        precision: dtype the transformer is cast to.

    Returns:
        ``(model, tokenizer, smodel)``: the transformer in eval mode, the
        tokenizer, and the speaker encoder.
    """
    ##### MODEL
    # Build on the meta device so no real weights are allocated;
    # load_state_dict(assign=True) below attaches the checkpoint tensors.
    with torch.device("meta"):
        model = Transformer.from_name("metavoice-1B")

    # TODO(quantization): enable
    # if "int8" in str(checkpoint_path):
    #     print("Using int8 weight-only quantization!")
    #     from quantize import WeightOnlyInt8QuantHandler
    #     simple_quantizer = WeightOnlyInt8QuantHandler(model)
    #     model = simple_quantizer.convert_for_runtime()
    # from quantize import WeightOnlyInt8QuantHandler
    # if "int4" in str(checkpoint_path):
    #     print("Using int4 quantization!")
    #     path_comps = checkpoint_path.name.split(".")
    #     assert path_comps[-2].startswith("g")
    #     groupsize = int(path_comps[-2][1:])
    #     from quantize import WeightOnlyInt4QuantHandler
    #     simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
    #     model = simple_quantizer.convert_for_runtime()

    # NOTE(security): weights_only=False unpickles arbitrary objects, so a
    # malicious checkpoint can execute code. Only load trusted checkpoints.
    # TODO confirm whether the "meta" entry is what requires full unpickling.
    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
    # Convert MetaVoice-1B model weights naming to gptfast naming.
    state_dict = _convert_metavoice_state_dict(checkpoint["model"])
    model.load_state_dict(state_dict, assign=True)
    # simple_quantizer = WeightOnlyInt8QuantHandler(model)
    # quantized_state_dict = simple_quantizer.create_quantized_state_dict()
    # model = simple_quantizer.convert_for_runtime()
    # model.load_state_dict(quantized_state_dict, assign=True)
    model = model.to(device=device, dtype=precision)

    ###### TOKENIZER
    tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {})
    tokenizer = TrainedBPETokeniser(**tokenizer_info)

    ###### SPEAKER EMBEDDER
    # TODO: fix!
    smodel = SpeakerEncoder(
        weights_fpath=spk_emb_ckpt_path,
        device=device,
        eval=True,
        verbose=False,
    )
    return model.eval(), tokenizer, smodel
def build_model(
    *,
    precision: torch.dtype,
    checkpoint_path: Path = Path(""),
    spk_emb_ckpt_path: Path = Path(""),
    compile_prefill: bool = False,
    compile: bool = True,
    device: str = "cuda",
):
    """Load the model stack, optionally torch.compile the hot paths, then warm up.

    Steps:

    1. Seed the global RNG with 1234.
    2. Allocate the speaker-conditioning mask and caches (batch size 2) on
       ``device``.
    3. When ``compile`` is true, rebind the module-level ``decode_one_token``
       (and, with ``compile_prefill``, ``prefill``) to ``torch.compile``
       versions.
    4. Run a 200-token warm-up generation that cannot stop early. This step
       runs even when ``compile`` is false.

    Returns:
        ``(model, tokenizer, smodel, model_size)``. ``model_size`` is the total
        byte size of the transformer's parameters and buffers.
    """
    # Function-wide directive: the assignments below rebind the module-level names.
    global decode_one_token, prefill

    assert checkpoint_path.is_file(), checkpoint_path
    print(f"Using device={device}")

    print("Loading model ...")
    load_start = time.time()
    model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision)
    device_sync(device=device)  # MKG
    print(f"Time to load model: {time.time() - load_start:.02f} seconds")

    torch.manual_seed(1234)
    all_tensors = itertools.chain(model.parameters(), model.buffers())
    model_size = sum(t.numel() * t.dtype.itemsize for t in all_tensors)

    with torch.device(device):
        model.setup_spk_cond_mask()
        model.setup_caches(max_batch_size=2, max_seq_length=model.config.block_size)

    if compile:
        print("Compiling...Can take up to 2 mins.")
        decode_one_token = torch.compile(decode_one_token, mode="max-autotune", fullgraph=True)
        if compile_prefill:
            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

    warmup_prompt = encode_tokens(tokenizer, "Hello, what's up?", device=device)
    warmup_spk_emb = torch.randn((1, 256), device=device, dtype=precision)

    def _scalar(value):
        # 0-d sampling parameter on the target device/dtype.
        return torch.tensor(value, device=device, dtype=precision)

    device_sync(device=device)  # MKG
    warmup_start = time.perf_counter()
    generate(
        model,
        warmup_prompt,
        warmup_spk_emb,
        max_new_tokens=200,
        callback=lambda x: x,
        temperature=_scalar(1.0),
        top_k=None,
        top_p=_scalar(0.95),
        guidance_scale=_scalar(3.0),
        end_of_audio_token=9999,  # don't end early for compilation stage.
    )
    device_sync(device=device)  # MKG
    print(f"Compilation time: {time.perf_counter() - warmup_start:.2f} seconds")

    return model, tokenizer, smodel, model_size
def main(
    *,
    model,
    tokenizer,
    model_size,
    prompt: str,
    guidance_scale: torch.Tensor,
    temperature: torch.Tensor,
    spk_emb: torch.Tensor,
    top_k: Optional[torch.Tensor] = None,
    top_p: Optional[torch.Tensor] = None,
    device: str = "cuda",
) -> list:
    """Generate a token sequence for a text prompt and print throughput stats.

    Printed stats:
        - wall time and tokens/sec of the generation;
        - a bandwidth estimate of ``model_size * tokens_per_sec``;
        - on CUDA devices only, the peak reserved CUDA memory.

    Args:
        model: Transformer prepared by ``build_model``.
        tokenizer: Tokenizer used to encode ``prompt``.
        model_size: Byte size of the model's parameters and buffers (from ``build_model``).
        prompt: Input text.
        guidance_scale, temperature, top_k, top_p: Sampling parameters
            forwarded to ``generate``.
        spk_emb: Speaker embedding forwarded to ``generate``.
        device: Device the prompt tensor is created on.

    Returns:
        The prompt token ids followed by the generated token ids, as a Python list.
    """
    encoded = encode_tokens(tokenizer, prompt, device=device)
    prompt_length = encoded.size(0)

    device_sync(device=device)  # MKG
    t0 = time.perf_counter()
    y = generate(
        model,
        encoded,
        spk_emb,
        callback=lambda x: x,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        guidance_scale=guidance_scale,
    )
    device_sync(device=device)  # MKG
    t = time.perf_counter() - t0

    tokens_generated = y.size(0) - prompt_length
    tokens_sec = tokens_generated / t
    print(f"Time for 1st stage LLM inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
    # Rough estimate; assumes every generated token reads all weights once.
    print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
    # CUDA memory stats are meaningless on other devices (they report 0).
    if "cuda" in str(device):
        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB\n")
    return y.tolist()