import copy
import logging

import torch
import torch.nn.functional as F
import torchaudio
from torch import Tensor, nn

from .model import length_to_mask
from .samplers import (apply_typical_p, early_eos_penalty, top_k_top_p_filtering,
                       freq_rep_penalty)
from .nn_future import RotatingBufferCache
from .minbpe.codebook import CodebookTokenizer
from .minbpe.regex import RegexTokenizer


@torch.inference_mode()
def ar_generate(texttok: RegexTokenizer, speechtok: CodebookTokenizer, codeclm: nn.Module,
                xx: Tensor, ss_gen: Tensor, first_codex_idx: int, max_len: int = 1500,
                fp16: bool = True, temperature: float = 1.0, topk: int | None = None,
                top_p: float = 1.0, alpha_frequency: float = 0, alpha_presence: float = 0,
                penalty_window: int = 100, typical_p: float = 1.0,
                eos_penalty_factor: float = 1.0, eos_penalty_decay: float = 0,
                n_phones_gen: int | None = None, vocode: bool = True, beam_width: int = 1,
                beam_length_penalty: float = 2, use_kv_cache: bool = True) -> Tensor:
    """
    Use the `codeclm` language model to autoregressively generate a completion of `xx`
    (seq_len), where the first `first_codex_idx`-1 indices correspond to the input phones.
    The output generation is limited to at most `max_len` tokens (measured as the number of
    latent codes). Decoding keeps track of the top `beam_width` candidates and selects the
    best one among them.

    Returns the generated first-quantizer codes. Vocoding the codes to audio (`vocode=True`)
    is not implemented in this function.

    - See `InferenceConfig` for documentation of the other inference parameters.
    """
    assert xx.dim() == 1, "Only batch size of 1 is currently supported."
    assert beam_width == 1, "Only beam size of 1 is currently supported."
    # Internally our batch size will be the beam width.
    bs = beam_width
    x_inp = xx[None].repeat(bs, 1)  # (bs, seq_len)
    ss_gen = ss_gen[None].repeat(bs, 1, 1)
    # We must subtract 1 in the line below so that we match the train-time condition of having
    # a False padding value at position `first_codex_idx - 1`. This is needed so that we
    # correctly use the _acoustic_ and not the linguistic language embedding for that token.
    offsets = torch.tensor([first_codex_idx - 1 for _ in range(bs)], dtype=torch.long,
                           device=xx.device)
    valid_logit_idx_start = len(texttok.vocab)  # vocab['s2i']['quant0-0000']
    valid_logit_idx_end = len(texttok.vocab) + len(speechtok.vocab) + 1  # vocab['s2i']['quant1-0000']
    # Make a mask that is True where we have valid outputs, False otherwise (where we have
    # text outputs).
    # logit_mask = torch.zeros(n_vocab, dtype=bool, device=x_inp.device)
    # logit_mask[valid_logit_idx_start:valid_logit_idx_end] = True
    # logit_mask[vocab['s2i']['<eos>']] = True

    cum_logprobs = torch.zeros(bs, dtype=torch.float, device=x_inp.device)
    eos_idx = len(texttok.vocab) + speechtok.special_tokens['<|endofspeech|>']
    n_vocab = len(texttok.vocab) + len(speechtok.vocab)
    logging.info(f"Starting beam decoding with beam_width={beam_width}")
    prev_ids = [[] for _ in range(bs)]

    cache = None
    if use_kv_cache:
        # Initialise the kv cache.
        cache_window = min(codeclm.ar.args.sliding_window, x_inp.shape[-1] + max_len)
        cache = RotatingBufferCache(codeclm.ar.args.n_layers, bs, cache_window,
                                    codeclm.ar.args.n_kv_heads, codeclm.ar.args.head_dim)
        cache.to(device=x_inp.device, dtype=torch.float16)

    counter = 0
    while x_inp.shape[-1] < max_len:
        counter += 1
        gen_length = torch.tensor([x_inp.shape[-1] for _ in range(bs)], dtype=torch.long,
                                  device=xx.device)
        padding_mask = length_to_mask(gen_length, offsets)

        with torch.autocast('cuda', enabled=fp16):
            logits: Tensor = codeclm(x_inp, padding_mask, spk_reference=ss_gen, cache=cache,
                                     counter=counter)
        logits = logits.float()
        logits = logits[:, -1]  # select last index, now (bs, logit_dim)

        # <---------------------- logit filtering ---------------------->
        filtered_logits = logits.clone()
        # Apply the repetition penalty before the logit mask if any item in the beam has more
        # than 1 prior token.
        if len(prev_ids[0]) > 1:
            filtered_logits = freq_rep_penalty(filtered_logits,
                                               previous=torch.tensor(prev_ids, dtype=torch.long),
                                               alpha_frequency=alpha_frequency,
                                               alpha_presence=alpha_presence,
                                               penalty_window=penalty_window)
        filtered_logits[..., :valid_logit_idx_start-1] = float('-inf')
        filtered_logits[..., valid_logit_idx_end:] = float('-inf')

        if n_phones_gen is not None:
            # Penalise emitting eos too early relative to the number of input phones.
            filtered_logits = early_eos_penalty(filtered_logits, len(prev_ids[0]), n_phones_gen,
                                                eos_penalty_decay, eos_penalty_factor,
                                                eos_index=eos_idx)

        filtered_logits = filtered_logits / temperature
        filtered_logits = top_k_top_p_filtering(filtered_logits, top_k=topk, top_p=top_p)
        filtered_logits = apply_typical_p(filtered_logits, mass=typical_p)
        # Mask out anything that isn't first-quantizer output codes.
        filtered_logits[..., :valid_logit_idx_start-1] = float('-inf')
        filtered_logits[..., valid_logit_idx_end:] = float('-inf')
        logits = filtered_logits

        # <---------------------- next frame prediction --------------------->
        logprobs = logits.log_softmax(dim=-1)
        # Update assignments: if any beam ended in the last step, it MUST also end in this step.
        # We enforce this by setting all non-eos log-probabilities of a finished beam to -inf
        # and its eos log-probability to 0 (i.e. probability 1).
        for j in range(bs):
            if x_inp[j, -1] == eos_idx:
                # Do not add any additional probability mass to this finished beam's
                # cumulative score.
                logprobs[j] = float('-inf')  # zero probability of anything non-eos after 1 eos
                logprobs[j, eos_idx] = 0  # probability = 1 of eos after eos

        candidate_cum_logprobs = cum_logprobs[:, None] + logprobs  # (bs, 1) + (bs, vocab) -> (bs, vocab)
        logp_flat = logprobs.flatten()
        candidates = torch.multinomial(logp_flat.exp(), num_samples=beam_width,
                                       replacement=False)  # (bs,)
        # Unravel the flat indices into (beam index, token index) pairs:
        beam_idxs = candidates // n_vocab  # (bs,)
        tok_inds_in_each_beam = candidates % n_vocab  # (bs,)

        # Check for termination: every beam sampled eos.
        if torch.all(tok_inds_in_each_beam == eos_idx):
            # Apply the length penalty:
            non_eos_toks = (x_inp != eos_idx).sum(dim=-1)  # (bs,) number of non-eos tokens
            gen_length = non_eos_toks - first_codex_idx
            penalties = gen_length**beam_length_penalty
            penalized_cum_tok_logp = candidate_cum_logprobs / penalties[:, None]
            eos_avg_logps = penalized_cum_tok_logp[:, eos_idx]
            best_beam_idx = eos_avg_logps.argmax()
            best_avg_logp = eos_avg_logps[best_beam_idx]
            best_beam = x_inp[best_beam_idx]
            logging.info((f"best beam = {best_beam_idx} @ penalized_cum_tok_logp = {best_avg_logp.item():.3f} |\n"
                          f" candidates: {eos_avg_logps.cpu()} |\n non-eos toks: {non_eos_toks.cpu().tolist()} |\n"
                          f" penalties: {penalties.cpu().tolist()} |\n"
                          f" raw cumulative logprobs: {candidate_cum_logprobs[:, eos_idx].cpu().tolist()}"))
            break

        # Update beam histories:
        x_inp = x_inp[beam_idxs]
        # Update the next token.
        next_sample = tok_inds_in_each_beam
        # Update the cumulative log-probabilities.
        cum_logprobs = cum_logprobs[beam_idxs] + logprobs[beam_idxs, tok_inds_in_each_beam]
        # Update prior ids to point to the correct beam.
        prev_ids = [copy.deepcopy(prev_ids[beam_idx.item()]) for beam_idx in beam_idxs]
        # Add the new tokens to the previous ids.
        for j in range(bs):
            prev_ids[j].append(tok_inds_in_each_beam[j].item())

        logging.debug("L%d | next sample: %s | beam: %s | cum_logp: %s", len(x_inp[0]),
                      next_sample.cpu().tolist(), beam_idxs.cpu().tolist(), cum_logprobs.cpu())

        # Reorder the kv cache to follow the selected beams.
        if cache is not None:
            cache.cache_k = cache.cache_k[:, beam_idxs]
            cache.cache_v = cache.cache_v[:, beam_idxs]

        # Add a None below to make (bs,) -> (bs, 1) so we can concat along the seq-len dim.
        x_inp = torch.cat([x_inp, next_sample[:, None]], dim=-1)

    if x_inp.shape[-1] >= max_len - 1:
        logging.warning(f"[autoregressive generation] output length = {x_inp.shape[-1]} -- "
                        "inference likely failed or input too long!")

    # With beam_width == 1 the first beam is the (only) best beam, whether or not generation
    # terminated with eos.
    best_beam = x_inp[0]
    if not vocode:
        return best_beam  # (seq_len,)
    else:
        raise NotImplementedError("vocode=True is not supported by this function; call with "
                                  "vocode=False to obtain the raw first-quantizer codes.")
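

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the original module): `ar_generate` samples a flat
# index over the joint (beam, vocab) grid and then unravels it back into a beam index and a
# token index with `//` and `%`. The helper below, with made-up shapes and a hypothetical
# name, reproduces that bookkeeping in isolation so the arithmetic is easy to verify.
def _demo_flat_beam_sampling(bs: int = 2, n_vocab: int = 8, beam_width: int = 2) -> None:
    """Toy reproduction of the flatten -> multinomial -> unravel step used in `ar_generate`."""
    logprobs = torch.randn(bs, n_vocab).log_softmax(dim=-1)  # (bs, n_vocab) per-beam log-probs
    flat = logprobs.flatten()                                # (bs * n_vocab,)
    cand = torch.multinomial(flat.exp(), num_samples=beam_width, replacement=False)
    beam_idxs = cand // n_vocab   # which beam each flat sample came from
    tok_idxs = cand % n_vocab     # which token within that beam
    # The unravelled pair addresses exactly the same log-probabilities as the flat index.
    assert torch.equal(flat[cand], logprobs[beam_idxs, tok_idxs])

# e.g. `_demo_flat_beam_sampling()` runs silently when the bookkeeping is consistent; with
# beam_width == 1 (the only size `ar_generate` currently supports) the unravel is trivial,
# but the same arithmetic generalises to wider beams.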