from typing import Optional

import torch
import torch.nn.functional as F


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    curr_pos_id: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Applies rotary positional embeddings to the input tensor.
    Args:
        x (torch.Tensor): The input tensor.
        freqs_cis (torch.Tensor): A tensor containing the precomputed rotary
            frequency components.
        curr_pos_id (Optional[torch.Tensor]): An optional tensor specifying the
            current position IDs to use for selecting a subset of `freqs_cis`.
            If None, the function uses the last `seq_len` positions.
    Returns:
        torch.Tensor: The input tensor `x` with rotary positional embeddings
        applied.
    """
    # Pair the last dimension into (real, imag) and view it as complex numbers:
    # (..., head_dim) -> (..., head_dim // 2), complex dtype.
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    if curr_pos_id is None:
        # No explicit positions: take the trailing seq_len entries of freqs_cis
        # and broadcast over the head dimension.
        freqs_cis = freqs_cis[:, -x.shape[2] :].unsqueeze(1)
    else:
        # Select frequencies at the given position ids (e.g. single-step decoding).
        freqs_cis = freqs_cis[:, curr_pos_id, :].unsqueeze(1)
    # Complex multiplication rotates each (real, imag) pair; flatten back to head_dim.
    y = torch.view_as_real(x_ * freqs_cis).flatten(3)
    return y.type_as(x)


@torch.no_grad()
def precompute_freqs_cis(dim: int, t: torch.Tensor, theta: float = 10000.0):
    """Calculate rotary embedding cos & sin, this is useful when every blocks in the network use same positional embedding.

    Args:
        dim (int): dimension of the single head of the transformer block
        t (torch.Tensor): position ids [..., L]
        theta (int, optional): rope theta. Defaults to 10000.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: tuple of cos and sin of rope
    """
    assert dim % 2 == 0, (
        "RoPE only supports embedding dimensions that are multiples of 2"
    )
    # Inverse frequencies: theta^(-2i / dim) for i = 0, 1, ..., dim // 2 - 1.
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim)
    )
    # Outer product of positions and inverse frequencies gives the rotation angles,
    # reshaped to [batch_size, seq_len, num_freqs].
    freqs = torch.outer(t.contiguous().view(-1), freqs).reshape(*t.shape, -1)
    # Unit-magnitude complex numbers cos(angle) + i*sin(angle).
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

    return freqs_cis


def scaled_dot_product_attention_with_rotary_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    freqs_cis: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    curr_pos_id: Optional[torch.Tensor] = None,
    is_causal: bool = False,
) -> torch.Tensor:
    """
    Computes scaled dot product attention on query, key and value tensors
    with rotary position embeddings on query and key.

    Without caching enabled,
        q should be (bs, nh, seqlen, hd).
        k and v should stay unchanged, (bs, nh, seqlen, hd).
    With caching enabled,
        q should be (bs, nh, 1, hd).
        k and v should stay unchanged, (bs, nh, 1, hd).
        causal_mask must be False.
    """
    q = apply_rotary_emb(q, freqs_cis, curr_pos_id=curr_pos_id)  # (bs, nh, l, hd)
    k = apply_rotary_emb(k, freqs_cis, curr_pos_id=None)  # (bs, nh, s + l, hd)

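    # Delegate to PyTorch's fused scaled_dot_product_attention; causality is only
    # enforced when no explicit attention mask is supplied.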
    x = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=attn_mask,
        dropout_p=0.0,
        is_causal=is_causal and attn_mask is None,
    )
    return x
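

if __name__ == "__main__":
    # Minimal usage sketch. The sizes below and the prefill/decode split are
    # illustrative assumptions, not values prescribed by this module.
    bs, n_heads, seq_len, head_dim = 2, 4, 16, 64

    # A single freqs_cis tensor can be shared by every attention block. Position
    # ids are created as float so torch.polar downstream sees floating-point angles.
    pos_ids = torch.arange(seq_len, dtype=torch.float32).unsqueeze(0).expand(bs, -1)
    freqs_cis = precompute_freqs_cis(head_dim, pos_ids)  # (bs, seq_len, head_dim // 2)

    # Prefill: full-sequence causal attention.
    q = torch.randn(bs, n_heads, seq_len, head_dim)
    k = torch.randn(bs, n_heads, seq_len, head_dim)
    v = torch.randn(bs, n_heads, seq_len, head_dim)
    out = scaled_dot_product_attention_with_rotary_emb(q, k, v, freqs_cis, is_causal=True)
    print(out.shape)  # torch.Size([2, 4, 16, 64])

    # Decode step: one new query attends over keys/values for the whole sequence
    # (standing in for a KV cache here); curr_pos_id selects the rotary phase of
    # the latest position.
    q_step = torch.randn(bs, n_heads, 1, head_dim)
    curr_pos_id = torch.tensor([seq_len - 1])
    out_step = scaled_dot_product_attention_with_rotary_emb(
        q_step, k, v, freqs_cis, curr_pos_id=curr_pos_id, is_causal=False
    )
    print(out_step.shape)  # torch.Size([2, 4, 1, 64])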