# Dia-1.6B: dia/model.py
import time
from enum import Enum
import dac
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from .audio import (
apply_audio_delay,
build_delay_indices,
build_revert_indices,
decode,
revert_audio_delay,
)
from .config import DiaConfig
from .layers import DiaModel
from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState
DEFAULT_SAMPLE_RATE = 44100
def _get_default_device():
if torch.cuda.is_available():
return torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
def _sample_next_token(
logits_BCxV: torch.Tensor,
temperature: float,
top_p: float,
cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
    """Samples one token per channel from logits using temperature, optional top-k, and top-p (nucleus) filtering."""
    if temperature == 0.0:
        return torch.argmax(logits_BCxV, dim=-1)
logits_BCxV = logits_BCxV / temperature
    # CFG top-k filter: mask everything outside the k largest logits per row.
    if cfg_filter_top_k is not None:
        _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
    # Nucleus (top-p) filtering: remove the tail of the distribution once the
    # cumulative probability exceeds top_p, always keeping the top token.
    if top_p < 1.0:
        probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
probs_BCxV, dim=-1, descending=True
)
cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
..., :-1
].clone()
sorted_indices_to_remove_BCxV[..., 0] = 0
indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
indices_to_remove_BCxV.scatter_(
dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
)
logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
sampled_indices_C = sampled_indices_BC.squeeze(-1)
return sampled_indices_C
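# Illustrative example (assumption, not from the source): for a [C, V] logits
# tensor, a call such as
#   _sample_next_token(torch.randn(4, 256), temperature=1.3, top_p=0.95, cfg_filter_top_k=35)
# returns a [C] tensor with one sampled index per channel.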
class ComputeDtype(str, Enum):
FLOAT32 = "float32"
FLOAT16 = "float16"
BFLOAT16 = "bfloat16"
def to_dtype(self) -> torch.dtype:
if self == ComputeDtype.FLOAT32:
return torch.float32
elif self == ComputeDtype.FLOAT16:
return torch.float16
elif self == ComputeDtype.BFLOAT16:
return torch.bfloat16
else:
raise ValueError(f"Unsupported compute dtype: {self}")
class Dia:
def __init__(
self,
config: DiaConfig,
compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
device: torch.device | None = None,
):
"""Initializes the Dia model.
Args:
            config: The configuration object for the model.
            compute_dtype: The dtype used for computation ("float32", "float16", or "bfloat16").
            device: The device to load the model onto. If None, the best available device is selected automatically.
Raises:
RuntimeError: If there is an error loading the DAC model.
"""
super().__init__()
self.config = config
self.device = device if device is not None else _get_default_device()
if isinstance(compute_dtype, str):
compute_dtype = ComputeDtype(compute_dtype)
self.compute_dtype = compute_dtype.to_dtype()
self.model = DiaModel(config, self.compute_dtype)
self.dac_model = None
@classmethod
def from_local(
cls,
config_path: str,
checkpoint_path: str,
compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
device: torch.device | None = None,
) -> "Dia":
"""Loads the Dia model from local configuration and checkpoint files.
Args:
config_path: Path to the configuration JSON file.
            checkpoint_path: Path to the model checkpoint (.pth) file.
            compute_dtype: The dtype used for computation ("float32", "float16", or "bfloat16").
            device: The device to load the model onto. If None, the best available device is selected automatically.
Returns:
An instance of the Dia model loaded with weights and set to eval mode.
Raises:
FileNotFoundError: If the config or checkpoint file is not found.
RuntimeError: If there is an error loading the checkpoint.
"""
config = DiaConfig.load(config_path)
if config is None:
raise FileNotFoundError(f"Config file not found at {config_path}")
dia = cls(config, compute_dtype, device)
try:
state_dict = torch.load(checkpoint_path, map_location=dia.device)
dia.model.load_state_dict(state_dict)
except FileNotFoundError:
raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
except Exception as e:
raise RuntimeError(
f"Error loading checkpoint from {checkpoint_path}"
) from e
dia.model.to(dia.device)
dia.model.eval()
dia._load_dac_model()
return dia
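    # Usage sketch (illustrative; both paths are placeholders):
    #   dia = Dia.from_local("config.json", "dia-v0_1.pth", compute_dtype="bfloat16")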
@classmethod
def from_pretrained(
cls,
model_name: str = "nari-labs/Dia-1.6B",
compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
device: torch.device | None = None,
) -> "Dia":
"""Loads the Dia model from a Hugging Face Hub repository.
Downloads the configuration and checkpoint files from the specified
repository ID and then loads the model.
Args:
            model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B").
            compute_dtype: The dtype used for computation ("float32", "float16", or "bfloat16").
            device: The device to load the model onto. If None, the best available device is selected automatically.
Returns:
An instance of the Dia model loaded with weights and set to eval mode.
Raises:
FileNotFoundError: If config or checkpoint download/loading fails.
RuntimeError: If there is an error loading the checkpoint.
"""
config_path = hf_hub_download(repo_id=model_name, filename="config.json")
checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
return cls.from_local(config_path, checkpoint_path, compute_dtype, device)
def _load_dac_model(self):
try:
dac_model_path = dac.utils.download()
dac_model = dac.DAC.load(dac_model_path).to(self.device)
except Exception as e:
raise RuntimeError("Failed to load DAC model") from e
self.dac_model = dac_model
def _prepare_text_input(self, text: str) -> torch.Tensor:
"""Encodes text prompt, pads, and creates attention mask and positions."""
text_pad_value = self.config.data.text_pad_value
max_len = self.config.data.text_length
byte_text = text.encode("utf-8")
replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
text_tokens = list(replaced_bytes)
current_len = len(text_tokens)
padding_needed = max_len - current_len
if padding_needed <= 0:
text_tokens = text_tokens[:max_len]
padded_text_np = np.array(text_tokens, dtype=np.uint8)
else:
padded_text_np = np.pad(
text_tokens,
(0, padding_needed),
mode="constant",
constant_values=text_pad_value,
).astype(np.uint8)
src_tokens = (
torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)
) # [1, S]
return src_tokens
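    # Illustrative example (not from the source): "[S1] Hi" becomes the byte
    # tokens [0x01, 0x20, 0x48, 0x69], right-padded with text_pad_value up to
    # config.data.text_length.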
def _prepare_audio_prompt(
self, audio_prompt: torch.Tensor | None
) -> tuple[torch.Tensor, int]:
num_channels = self.config.data.channels
audio_bos_value = self.config.data.audio_bos_value
audio_pad_value = self.config.data.audio_pad_value
delay_pattern = self.config.data.delay_pattern
max_delay_pattern = max(delay_pattern)
prefill = torch.full(
(1, num_channels),
fill_value=audio_bos_value,
dtype=torch.int,
device=self.device,
)
prefill_step = 1
if audio_prompt is not None:
prefill_step += audio_prompt.shape[0]
prefill = torch.cat([prefill, audio_prompt], dim=0)
delay_pad_tensor = torch.full(
(max_delay_pattern, num_channels),
fill_value=-1,
dtype=torch.int,
device=self.device,
)
prefill = torch.cat([prefill, delay_pad_tensor], dim=0)
delay_precomp = build_delay_indices(
B=1,
T=prefill.shape[0],
C=num_channels,
delay_pattern=delay_pattern,
)
prefill = apply_audio_delay(
audio_BxTxC=prefill.unsqueeze(0),
pad_value=audio_pad_value,
bos_value=audio_bos_value,
precomp=delay_precomp,
).squeeze(0)
return prefill, prefill_step
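    # Note on the delay mechanics (assumption about audio.apply_audio_delay):
    # each channel c is shifted by delay_pattern[c] frames, which is why
    # max(delay_pattern) extra rows are appended to the prefill before the
    # delay is applied.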
def _prepare_generation(
self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool
):
        enc_input_cond = self._prepare_text_input(text)
        # Batch the unconditional (all-zero) and conditional inputs together so
        # one forward pass serves classifier-free guidance.
        enc_input_uncond = torch.zeros_like(enc_input_cond)
        enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)
if isinstance(audio_prompt, str):
audio_prompt = self.load_audio(audio_prompt)
prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)
if verbose:
print("generate: data loaded")
enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
encoder_out = self.model.encoder(enc_input, enc_state)
dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(
encoder_out, enc_state.positions
)
dec_state = DecoderInferenceState.new(
self.config,
enc_state,
encoder_out,
dec_cross_attn_cache,
self.compute_dtype,
)
dec_output = DecoderOutput.new(self.config, self.device)
dec_output.prefill(prefill, prefill_step)
dec_step = prefill_step - 1
if dec_step > 0:
dec_state.prepare_step(0, dec_step)
tokens_BxTxC = (
dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
)
self.model.decoder.forward(tokens_BxTxC, dec_state)
return dec_state, dec_output
def _decoder_step(
self,
tokens_Bx1xC: torch.Tensor,
dec_state: DecoderInferenceState,
cfg_scale: float,
temperature: float,
top_p: float,
cfg_filter_top_k: int,
) -> torch.Tensor:
audio_eos_value = self.config.data.audio_eos_value
logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)
logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
        uncond_logits_CxV = logits_last_BxCxV[0, :, :]
        cond_logits_CxV = logits_last_BxCxV[1, :, :]
        # Classifier-free guidance: move the conditional logits further away
        # from the unconditional ones by cfg_scale.
        logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
        # Forbid tokens beyond EOS on every channel, and EOS itself on all but
        # the first channel.
        logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
        logits_CxV[1:, audio_eos_value:] = -torch.inf
pred_C = _sample_next_token(
logits_CxV.float(),
temperature=temperature,
top_p=top_p,
cfg_filter_top_k=cfg_filter_top_k,
)
return pred_C
def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
num_channels = self.config.data.channels
seq_length = generated_codes.shape[0]
delay_pattern = self.config.data.delay_pattern
audio_pad_value = self.config.data.audio_pad_value
max_delay_pattern = max(delay_pattern)
revert_precomp = build_revert_indices(
B=1,
T=seq_length,
C=num_channels,
delay_pattern=delay_pattern,
)
codebook = revert_audio_delay(
audio_BxTxC=generated_codes.unsqueeze(0),
pad_value=audio_pad_value,
precomp=revert_precomp,
T=seq_length,
)[:, :-max_delay_pattern, :]
        # Clamp out-of-range codes (e.g. leftover BOS/PAD values) to a valid
        # codebook index before DAC decoding.
        min_valid_index = 0
        max_valid_index = 1023
        invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
        codebook[invalid_mask] = 0
audio = decode(self.dac_model, codebook.transpose(1, 2))
return audio.squeeze().cpu().numpy()
def load_audio(self, audio_path: str) -> torch.Tensor:
audio, sr = torchaudio.load(audio_path, channels_first=True) # C, T
if sr != DEFAULT_SAMPLE_RATE:
audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
audio = audio.to(self.device).unsqueeze(0) # 1, C, T
audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
_, encoded_frame, _, _, _ = self.dac_model.encode(audio_data) # 1, C, T
return encoded_frame.squeeze(0).transpose(0, 1)
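    # Usage sketch (illustrative; the path is a placeholder):
    #   codes = dia.load_audio("prompt.wav")  # -> [T, C] DAC codebook indices
    #   audio = dia.generate("[S1] Hello.", audio_prompt=codes)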
def save_audio(self, path: str, audio: np.ndarray):
import soundfile as sf
sf.write(path, audio, DEFAULT_SAMPLE_RATE)
@torch.inference_mode()
def generate(
self,
text: str,
max_tokens: int | None = None,
cfg_scale: float = 3.0,
temperature: float = 1.3,
top_p: float = 0.95,
use_torch_compile: bool = False,
cfg_filter_top_k: int = 35,
audio_prompt: str | torch.Tensor | None = None,
audio_prompt_path: str | None = None,
use_cfg_filter: bool | None = None,
verbose: bool = False,
    ) -> np.ndarray | None:
        """Generates audio for a text prompt and returns the waveform as a NumPy array, or None if nothing was generated."""
        audio_eos_value = self.config.data.audio_eos_value
audio_pad_value = self.config.data.audio_pad_value
delay_pattern = self.config.data.delay_pattern
max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
max_delay_pattern = max(delay_pattern)
self.model.eval()
if audio_prompt_path:
print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
audio_prompt = audio_prompt_path
if use_cfg_filter is not None:
print("Warning: use_cfg_filter is deprecated.")
if verbose:
total_start_time = time.time()
dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
dec_step = dec_output.prefill_step - 1
bos_countdown = max_delay_pattern
eos_detected = False
eos_countdown = -1
if use_torch_compile:
step_fn = torch.compile(self._decoder_step, mode="default")
else:
step_fn = self._decoder_step
if verbose:
print("generate: starting generation loop")
            if use_torch_compile:
                print(
                    "generate: with use_torch_compile=True, the first step will take a while to compile"
                )
start_time = time.time()
while dec_step < max_tokens:
dec_state.prepare_step(dec_step)
tokens_Bx1xC = (
dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
)
pred_C = step_fn(
tokens_Bx1xC,
dec_state,
cfg_scale,
temperature,
top_p,
cfg_filter_top_k,
)
if (
not eos_detected and pred_C[0] == audio_eos_value
) or dec_step == max_tokens - max_delay_pattern - 1:
eos_detected = True
eos_countdown = max_delay_pattern
if eos_countdown > 0:
step_after_eos = max_delay_pattern - eos_countdown
for i, d in enumerate(delay_pattern):
if step_after_eos == d:
pred_C[i] = audio_eos_value
elif step_after_eos > d:
pred_C[i] = audio_pad_value
eos_countdown -= 1
bos_countdown = max(0, bos_countdown - 1)
dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)
if eos_countdown == 0:
break
dec_step += 1
            if verbose and dec_step % 86 == 0:
                # 86 decoder steps correspond to roughly one second of audio at
                # the DAC frame rate, hence the realtime factor shown below.
                duration = time.time() - start_time
                print(
                    f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
                )
                start_time = time.time()
if dec_output.prefill_step >= dec_step + 1:
print("Warning: Nothing generated")
return None
generated_codes = dec_output.generated_tokens[
dec_output.prefill_step : dec_step + 1, :
]
if verbose:
total_step = dec_step + 1 - dec_output.prefill_step
total_duration = time.time() - total_start_time
print(
f"generate: total step={total_step}, total duration={total_duration:.3f}s"
)
return self._generate_output(generated_codes)
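

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # It downloads the default checkpoint from the Hugging Face Hub, generates
    # speech for a short two-speaker prompt, and writes it to disk; the output
    # filename is a placeholder.
    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
    audio = model.generate("[S1] Hello there. [S2] Hi, how are you?", verbose=True)
    if audio is not None:
        model.save_audio("dia_sample.wav", audio)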