import torch


# copy from: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L84
def precompute_freqs_cis(dim: int,
                         end: int,
                         theta: float = 10000.0) -> torch.Tensor:
    """Precomputes the frequency cis."""
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
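
# Illustrative check (an addition to the original snippet, not part of the
# source file): for a head dimension of 64 and up to 1024 positions, row t of
# the table holds exp(1j * t * theta_k) for each of the dim // 2 = 32
# frequencies theta_k.
_freqs_cis = precompute_freqs_cis(64, 1024)
assert _freqs_cis.shape == (1024, 32)
assert _freqs_cis.dtype == torch.complex64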


# modified from:
# https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L95
def google_apply_rotary_emb(x: torch.Tensor,
                            freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors."""
    # Pair feature k with feature k + dim // 2 ("half-split" pairing), then
    # rotate each pair by the per-position complex factor in freqs_cis.
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    # Undo the pairing: put the rotated real/imaginary parts back into the
    # original half-split feature layout.
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1)
    return x_out
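
# Shape sketch (an illustrative addition, with hypothetical sizes): with x
# laid out as (batch, time, head, head_dim), freqs_cis must be broadcastable
# against the complex view, e.g. reshaped to (1, time, 1, head_dim // 2).
_x = torch.randn(2, 16, 4, 64)
_fc = precompute_freqs_cis(64, 16).reshape(1, 16, 1, 32)
assert google_apply_rotary_emb(_x, _fc).shape == (2, 16, 4, 64)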


def llama_apply_rotary_emb(x: torch.Tensor,
                           freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding, LLaMA-style (adjacent-pair layout)."""
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
    return x_out.type_as(x)
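
# The LLaMA variant pairs adjacent features (2k, 2k + 1) rather than the
# half-split pairing (k, k + head_dim // 2) above, so the two layouts are not
# interchangeable for a given checkpoint; weights trained with one layout must
# be rotated with the same layout. Illustrative check, reusing the
# hypothetical tensors from the sketch above:
assert llama_apply_rotary_emb(_x, _fc).shape == (2, 16, 4, 64)
assert not torch.allclose(llama_apply_rotary_emb(_x, _fc),
                          google_apply_rotary_emb(_x, _fc))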


# Registry mapping a rotary-embedding style name to its implementation.
WENET_APPLY_ROTARY_EMB = {
    'google': google_apply_rotary_emb,
    'llama': llama_apply_rotary_emb,
}
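
# Dispatch sketch (illustrative, not part of the original snippet): callers
# select a style by name, e.g. from a model config string, and apply the
# matching rotation.
_rotary_style = 'llama'
_out = WENET_APPLY_ROTARY_EMB[_rotary_style](_x, _fc)
assert _out.shape == _x.shape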