# JAX port of the original PyTorch attention/RoPE helpers; the torch calls are
# kept in comments for reference.
import math
import jax
import jax.numpy as jnp
from einops import rearrange
from flax import nnx

Tensor = jax.Array

def check_tpu():
    # True if any visible JAX device is a TPU (intended to select the TPU
    # flash-attention branch below, which is currently disabled).
    return any('TPU' in d.device_kind for d in jax.devices())

if False:  # TPU flash-attention path, currently disabled (see check_tpu above).
    from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
    # q,  # [batch_size, num_heads, q_seq_len, d_model]
    # k,  # [batch_size, num_heads, kv_seq_len, d_model]
    # v,  # [batch_size, num_heads, kv_seq_len, d_model]
    def flash_mha(q, k, v):
        return flash_attention(q, k, v, sm_scale=1/math.sqrt(q.shape[-1]))
else:
    from jax.experimental.pallas.ops.gpu.attention import mha, mha_reference
    def pallas_mha(q, k, v):
        # q, k, v: [batch, seq_len, num_heads, head_dim].
        # The Pallas GPU kernel works on 128-wide blocks, so pad the sequence
        # dimension up to a multiple of 128, mask the padding out with segment
        # ids (0 = real tokens, 1 = padding), then trim the output back.
        # (mha_reference with the same arguments is the pure-JAX fallback for debugging.)
        q_len = q.shape[1]
        diff = (-q_len) & 127
        segment_ids = jnp.zeros((q.shape[0], q_len), dtype=jnp.int32)
        segment_ids = jnp.pad(segment_ids, ((0, 0), (0, diff)), mode="constant", constant_values=1)
        q, k, v = map(
            lambda x: jnp.pad(x, ((0, 0), (0, diff), (0, 0), (0, 0)), mode="constant", constant_values=0),
            (q, k, v),
        )
        out = mha(q, k, v, segment_ids=segment_ids, sm_scale=1 / math.sqrt(q.shape[-1]))
        return out[:, :q_len]
    from flux.modules.attention_flax import jax_memory_efficient_attention

    # Prefer the flash_attn_jax kernel when available; otherwise fall back to
    # the Pallas GPU kernel above (nnx.dot_product_attention is another option).
    try:
        from flash_attn_jax import flash_mha
    except ImportError:
        flash_mha = pallas_mha


    def dot_product_attention(q, k, v, sm_scale=1.0):
        # q, k, v: [batch, num_heads, seq_len, head_dim]. sm_scale is accepted
        # for API compatibility but unused; nnx.dot_product_attention always
        # scales by 1/sqrt(head_dim).
        q, k, v = map(lambda x: rearrange(x, "b h n d -> b n h d"), (q, k, v))
        ret = nnx.dot_product_attention(q, k, v)
        # Alternatives tried here: pallas_mha(q, k, v), or
        # jax_memory_efficient_attention(q, k, v, query_chunk_size=...) with the
        # chunk size chosen as seq_len/64, /16, /4, or seq_len, whichever divides evenly.
        return rearrange(ret, "b n h d -> b h n d")

def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    # q, k, v: [batch, num_heads, seq_len, head_dim]; pe: rotary frequencies from rope().
    # torch reference: x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    q, k = apply_rope(q, k, pe)
    q, k, v = map(lambda x: rearrange(x, "B H L D -> B L H D"), (q, k, v))
    x = flash_mha(q, k, v)  # pallas_mha / nnx.dot_product_attention are drop-in alternatives
    x = rearrange(x, "B L H D -> B L (H D)")
    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    # Build rotary-embedding frequencies for positions `pos` of shape [b, n]:
    # returns [b, n, dim/2, 2, 2], one 2x2 rotation matrix per frequency.
    assert dim % 2 == 0
    # torch reference used float64 here:
    # scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    scale = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim
    omega = 1.0 / (theta**scale)
    # out = torch.einsum("...n,d->...nd", pos, omega)
    out = jnp.einsum("...n,d->...nd", pos.astype(jnp.float32), omega)
    # out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = jnp.stack([jnp.cos(out), -jnp.sin(out), jnp.sin(out), jnp.cos(out)], axis=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    # return out.float()
    return out.astype(jnp.float32)

def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # Apply the 2x2 rotations in freqs_cis to consecutive channel pairs of the
    # queries and keys, computing in float32 and casting back to the input dtype.
    # torch reference:
    # xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    # xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    # return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
    return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
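
# ---------------------------------------------------------------------------
# Minimal shape-check sketch (not part of the original module; sizes are
# assumed for illustration). It builds rotary frequencies with rope(), applies
# them with apply_rope(), and runs attention through flax.nnx so it also works
# on CPU; attention() itself dispatches to flash_mha, which needs either
# flash_attn_jax or the Pallas GPU kernel.
if __name__ == "__main__":
    B, H, L, D = 1, 4, 32, 64  # assumed batch, heads, sequence length, head_dim

    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (B, H, L, D))
    k = jax.random.normal(kk, (B, H, L, D))
    v = jax.random.normal(kv, (B, H, L, D))

    # Single positional axis covering the whole head_dim; callers may instead
    # concatenate several rope() outputs whose dims sum to D.
    pos = jnp.arange(L, dtype=jnp.float32)[None, :]   # [1, L]
    pe = rope(pos, D, 10000)[:, None]                 # [1, 1, L, D/2, 2, 2]

    qr, kr = apply_rope(q, k, pe)                     # same shapes as q and k
    qr, kr, vr = map(lambda x: rearrange(x, "B H L D -> B L H D"), (qr, kr, v))
    out = nnx.dot_product_attention(qr, kr, vr)
    out = rearrange(out, "B L H D -> B L (H D)")
    print(out.shape)  # (1, 32, 256)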