File size: 6,069 Bytes
e45d058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright (c) 2023, Tri Dao.

import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_func, apply_rotary_emb_qkv_
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX
from transformers.models.gpt_neox.modeling_gpt_neox import (
    apply_rotary_pos_emb as apply_rotary_pos_emb_neox,
)
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding


# NeoX-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary(rotary_emb_fraction, seqlen_offset):
    """Check the non-interleaved RotaryEmbedding against the GPT-NeoX reference.

    A random fp16 ``qkv`` tensor of shape (batch, seqlen, 3, nheads, headdim) is
    pushed through ``RotaryEmbedding`` and, for q and k separately, through the
    ``transformers`` GPT-NeoX helpers. The test asserts that:

    * the cached cos/sin tables match the first ``rotary_dim // 2`` columns of
      the reference tables,
    * the rotated q/k slices match the reference outputs within tolerance,
    * the non-rotary tail of q/k and all of v are returned bit-identical,
    * the same three properties hold for the gradients after ``backward``.

    Args:
        rotary_emb_fraction: fraction of ``headdim`` that gets rotated.
        seqlen_offset: starting position offset passed to both implementations;
            the tested sequence length is ``2048 - seqlen_offset``.
    """
    device, dtype = "cuda", torch.float16
    tol = {"rtol": 1e-3, "atol": 5e-3}
    # set seed
    torch.random.manual_seed(0)

    batch_size, nheads, headdim = 8, 16, 128
    seqlen_total = 2048
    seqlen = seqlen_total - seqlen_offset
    rotary_dim = int(headdim * rotary_emb_fraction)
    half_dim = rotary_dim // 2

    def to_bhsd(t):
        # Layout the reference implementation is fed with.
        return rearrange(t, "b s h d -> b h s d")

    def to_bshd(t):
        # Back to the layout used by qkv / out.
        return rearrange(t, "b h s d -> b s h d")

    def rotated_part(t, which):
        # Rotary slice of q (which=0) or k (which=1) from a packed 5-D tensor.
        return t[:, :, which, :, :rotary_dim]

    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
    )
    # Snapshot first: our implementation modifies qkv inplace.
    qkv_og = qkv.clone().detach()

    rotary = RotaryEmbedding(rotary_dim, device=device)
    rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)
    # Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor
    cos_neox, sin_neox = (
        table.to(dtype=dtype) for table in rotary_neox(qkv, seq_len=seqlen_total)
    )

    # Independent leaf copies of q and k for the reference path; they must be
    # taken before `rotary(qkv, ...)` runs.
    q_pt, k_pt = (
        to_bhsd(rotated_part(qkv, which)).detach().clone().requires_grad_(True)
        for which in (0, 1)
    )
    q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)
    out = rotary(qkv, seqlen_offset=seqlen_offset)

    # Forward: cached tables, rotated slices, untouched remainder.
    cached_vs_reference = (
        (rotary._cos_cached, cos_neox),
        (rotary._sin_cached, sin_neox),
    )
    for cached, reference in cached_vs_reference:
        assert torch.allclose(cached, reference[..., :half_dim].to(dtype=dtype), **tol)
    for which, reference_out in enumerate((q_neox, k_neox)):
        assert torch.allclose(to_bshd(reference_out), rotated_part(out, which), **tol)
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])

    # Backward: feed the same upstream gradient to both implementations.
    g = torch.randn_like(out)
    # Snapshot first: our implementation modifies g inplace.
    g_og = g.clone().detach()
    out.backward(g)
    for which, reference_out in enumerate((q_neox, k_neox)):
        reference_out.backward(to_bhsd(rotated_part(g_og, which)))

    for which, reference_in in enumerate((q_pt, k_pt)):
        assert torch.allclose(
            to_bshd(reference_in.grad), rotated_part(qkv.grad, which), **tol
        )
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])


# GPT-J-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
    """Check the interleaved RotaryEmbedding against the GPT-J reference.

    A random fp16 ``qkv`` tensor of shape (batch, seqlen, 3, nheads, headdim) is
    pushed through ``RotaryEmbedding(interleaved=True)`` and, for q and k
    separately, through the ``transformers`` GPT-J helpers. The test asserts
    that:

    * the cached cos/sin tables match the reference sin/cos tables,
    * the rotated q/k slices match the reference outputs within tolerance,
    * the non-rotary tail of q/k and all of v are returned bit-identical,
    * the same three properties hold for the gradients after ``backward``.

    Args:
        rotary_emb_fraction: fraction of ``headdim`` that gets rotated.
        seqlen_offset: starting position offset passed to both implementations;
            the tested sequence length is ``2048 - seqlen_offset``.
    """
    device, dtype = "cuda", torch.float16
    tol = {"rtol": 1e-3, "atol": 5e-3}
    # set seed
    torch.random.manual_seed(0)

    batch_size, nheads, headdim = 8, 16, 128
    seqlen_total = 2048
    seqlen = seqlen_total - seqlen_offset
    rotary_dim = int(headdim * rotary_emb_fraction)

    def rotated_part(t, which):
        # Rotary slice of q (which=0) or k (which=1) from a packed 5-D tensor.
        return t[:, :, which, :, :rotary_dim]

    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
    )
    # Snapshot first: our implementation modifies qkv inplace.
    qkv_og = qkv.clone().detach()

    rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)
    # Reference tables; index 0 is compared against sin and index 1 against cos below.
    sincos_gptj = tuple(
        table.to(dtype=dtype)
        for table in fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)
    )

    # Independent leaf copies of q and k for the reference path; they must be
    # taken before `rotary(qkv, ...)` runs.
    q_pt, k_pt = (
        rotated_part(qkv, which).detach().clone().requires_grad_(True) for which in (0, 1)
    )
    q_gptj, k_gptj = (
        apply_rotary_pos_emb_gptj(reference_in, sincos_gptj, offset=seqlen_offset)
        for reference_in in (q_pt, k_pt)
    )
    out = rotary(qkv, seqlen_offset=seqlen_offset)

    # Forward: cached tables, rotated slices, untouched remainder.
    sin_reference = sincos_gptj[0]
    cos_reference = sincos_gptj[1]
    assert torch.allclose(rotary._cos_cached, cos_reference, **tol)
    assert torch.allclose(rotary._sin_cached, sin_reference, **tol)
    for which, reference_out in enumerate((q_gptj, k_gptj)):
        assert torch.allclose(reference_out, rotated_part(out, which), **tol)
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])

    # Backward: feed the same upstream gradient to both implementations.
    g = torch.randn_like(out)
    # Snapshot first: our implementation modifies g inplace.
    g_og = g.clone().detach()
    out.backward(g)
    for which, reference_out in enumerate((q_gptj, k_gptj)):
        reference_out.backward(rotated_part(g_og, which))

    for which, reference_in in enumerate((q_pt, k_pt)):
        assert torch.allclose(reference_in.grad, rotated_part(qkv.grad, which), **tol)
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])