# wenet/utils/rope_utils.py
import torch


# Copied from:
# https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L84
def precompute_freqs_cis(dim: int,
                         end: int,
                         theta: float = 10000.0) -> torch.Tensor:
    """Precomputes the rotary frequency table as complex exponentials (cis).

    Returns a complex64 tensor of shape (end, dim // 2) whose entry [t, i]
    is exp(1j * t * theta**(-2 * i / dim)).
    """
    # Per-dimension angular frequencies: theta^(-2i/dim) for i in [0, dim/2).
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    # Outer product of positions and frequencies gives the rotation angles.
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    # polar(1, angle) = cos(angle) + 1j * sin(angle).
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
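
# Example (illustrative; the numbers are assumptions, not from this file):
#     fc = precompute_freqs_cis(64, 1024)  # complex64, shape (1024, 32)
#     fc[0]  # all ones: position t = 0 is left unrotated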


# Modified from:
# https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L95
def google_apply_rotary_emb(x: torch.Tensor,
                            freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies rotary embedding to a single query or key tensor (Gemma
    convention): element i of the first half of the last dimension is
    paired with element i of the second half to form a complex number.
    """
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    # Reassemble the half-split layout (real parts first, imaginary parts
    # second), then flatten the trailing dims back to the head dimension.
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1)
    return x_out
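
# Note (added commentary): freqs_cis must broadcast against the complex view
# of x, whose last dimension is halved. For a 4-D x of shape
# (batch, time, heads, head_dim), a (time, head_dim // 2) table needs
# reshaping to (1, time, 1, head_dim // 2) first, as the demo below does.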


def llama_apply_rotary_emb(x: torch.Tensor,
                           freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies rotary embedding (Llama convention): adjacent elements
    (x[..., 2i], x[..., 2i + 1]) of the last dimension form complex pairs."""
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
    return x_out.type_as(x)
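
# Note (added commentary): the two conventions pair different elements of the
# last dimension, so they produce different outputs for the same input; a
# model trained with one layout must be applied with the matching function.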


# Registry mapping a rotary-embedding style name to its apply function.
WENET_APPLY_ROTARY_EMB = {
    'google': google_apply_rotary_emb,
    'llama': llama_apply_rotary_emb,
}
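

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; the shapes below are
    # assumptions, not taken from the original module). Both conventions
    # should preserve the input shape and leave position 0 unrotated.
    batch, time, heads, head_dim = 2, 16, 4, 64
    x = torch.randn(batch, time, heads, head_dim)
    freqs_cis = precompute_freqs_cis(head_dim, time)  # (time, head_dim // 2)
    # Make the table broadcastable over the batch and head dimensions.
    fc = freqs_cis.reshape(1, time, 1, head_dim // 2)
    for name, apply_fn in WENET_APPLY_ROTARY_EMB.items():
        out = apply_fn(x, fc)
        assert out.shape == x.shape
        # freqs_cis[0] is all ones, so position 0 passes through unchanged.
        assert torch.allclose(out[:, 0], x[:, 0], atol=1e-6)
        print(f"{name}: {tuple(out.shape)}")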