emo-knob / fam /llm /mixins /non_causal.py
tonychenxyz's picture
init
9e34a62
raw
history blame
2.92 kB
from typing import Optional
import torch
from torch.nn import functional as F
class NonCausalInferenceMixin:
"""
Mixin class for non-causal inference in a language model.
This class provides methods for performing non-causal sampling using a language model.
"""
@torch.no_grad()
def _non_causal_sample(
self, *, idx: torch.Tensor, speaker_embs: Optional[torch.Tensor], temperature: float, top_k: int
):
"""
Perform non-causal sampling.
Args:
idx (torch.Tensor): Input tensor of shape (batch_size, num_in_hierarchies, sequence_length).
speaker_embs (Optional[torch.Tensor]): Speaker embeddings tensor of shape (batch_size, embedding_size).
temperature (float): Temperature parameter for scaling the logits.
top_k (int): Number of top options to consider.
Returns:
torch.Tensor: Sampled output tensor of shape (batch_size, num_out_hierarchies, sequence_length).
"""
b, c, t = idx.size()
assert t == self.config.block_size, f"input size {t} != config.block_size {self.config.block_size}"
# forward the model to get the logits for the index in the sequence
list_logits, _ = self(idx, speaker_embs=speaker_embs) # c x (b, t, vocab_size)
# scale by desired temperature
list_logits = [logits / temperature for logits in list_logits] # c x (b, t, vocab_size)
# optionally crop the logits to only the top k options
if top_k is not None:
for i in range(len(list_logits)):
logits = list_logits[i] # (b, t, vocab_size)
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) # (b, t, top_k)
logits[logits < v[:, :, [-1]]] = -float("Inf")
list_logits[i] = logits # (b, t, vocab_size)
assert logits.shape[0] == b and logits.shape[1] == t
# apply softmax to convert logits to (normalized) probabilities
# TODO: check shapes here!
probs = [F.softmax(logits, dim=-1) for logits in list_logits] # c x (b, t, top_k)
assert probs[0].shape[0] == b and probs[0].shape[1] == t
# TODO: output shape is as expected
outs = []
for b_prob in probs: # c x (b, t, top_k) -> (b, t, top_k)
out = [
torch.multinomial(prob, num_samples=1).transpose(0, 1).unsqueeze(0) for prob in b_prob
] # b x (t, top_k) -> b x (t, 1) -> b x (1, t) -> b x (1, 1, t)
assert len(out) == b and out[0].shape[0] == 1 and out[0].shape[1] == 1 and out[0].shape[2] == t
out = torch.cat(out, dim=0) # (b, 1, t)
assert out.shape[0] == b and out.shape[1] == 1 and out.shape[2] == t
outs.append(out)
out = torch.cat(outs, dim=1) # (b, c, t)
assert out.shape[0] == b and out.shape[2] == t
return out