# pit/models_jittor/utils.py
import math
import jittor as jt
import jittor.nn as nn
class NewGELUActivation(jt.Module):
    """
    Tanh-approximated GELU as used in the Google BERT repo (identical to OpenAI GPT).
    See the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """
    def execute(self, input):
        # Compute the cubic term in float32 for numerical stability.
        output = (input + 0.044715 * jt.pow(input.float(), 3))
        if jt.flags.amp_level >= 1:
            # Under automatic mixed precision, cast back to half before the tanh.
            output = output.half()
        return 0.5 * input * (1.0 + jt.tanh(math.sqrt(2.0 / math.pi) * output))
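
# Minimal usage sketch (illustrative; not part of the original module): the
# activation is a plain jt.Module, so it can be called like any other layer.
def _demo_new_gelu():
    act = NewGELUActivation()
    x = jt.randn(2, 8)
    y = act(x)  # same shape as x; GELU(0) == 0 and GELU(x) ~ x for large positive x
    assert y.shape == x.shape
    return y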
def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
    """Build the sin/cos tables for rotary position embedding (RoPE).

    Returns two [seq_len, dim // 2] tables, where dim is x's last dimension.
    """
    dim = x.shape[-1]
    if seq_len is None:
        seq_len = x.shape[seq_dim]
    # Standard RoPE inverse frequencies: 1 / 10000^(2i / dim).
    inv_freq = 1.0 / (10000 ** (jt.arange(0, dim, 2) / dim))
    # Outer product of positions and inverse frequencies: [seq_len, dim // 2].
    sinusoid_inp = (
        jt.einsum("i , j -> i j", jt.arange(seq_len, dtype=jt.float), inv_freq).float()
    )
    if jt.flags.use_tensorcore:
        sinusoid_inp = sinusoid_inp.half()
    return jt.sin(sinusoid_inp), jt.cos(sinusoid_inp)
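
# Shape sketch (illustrative; not part of the original module): for a
# [batch, seq, heads, head_dim] tensor with seq_dim=1, the returned sin/cos
# tables are both [seq, head_dim // 2].
def _demo_fixed_pos_embedding():
    x = jt.randn(2, 16, 4, 64)  # [batch, seq, heads, head_dim]
    sin, cos = fixed_pos_embedding(x, seq_dim=1)
    assert list(sin.shape) == [16, 32] and list(cos.shape) == [16, 32]
    return sin, cos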
def rotate_every_two(x):
    """Rotate channel pairs along the last dim: (x1, x2) -> (-x2, x1)."""
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = jt.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
def duplicate_interleave(m):
"""
    A simple version of `repeat_interleave` (cf. `torch.repeat_interleave`) for duplicating a matrix while interleaving the copy.
"""
dim0 = m.shape[0]
m = m.view(-1, 1) # flatten the matrix
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
return m
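
# Illustrative sketch (not part of the original module): each row [a, b, c]
# becomes [a, a, b, b, c, c], the interleaved layout that matches the
# (-x2, x1) pairing produced by rotate_every_two.
def _demo_duplicate_interleave():
    m = jt.array([[1.0, 2.0, 3.0]])
    out = duplicate_interleave(m)  # [[1., 1., 2., 2., 3., 3.]]
    assert list(out.shape) == [1, 6]
    return out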
def apply_rotary_pos_emb(x, sincos, offset=0):
    """Apply rotary position embedding to x of shape [batch, seq, heads, head_dim]."""
    sin, cos = (duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :] for t in sincos)
    # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
    return (x * cos) + (rotate_every_two(x) * sin)
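
# End-to-end sketch (illustrative; not part of the original module): rotate a
# query tensor with the tables from fixed_pos_embedding; keys would be rotated
# the same way before computing attention scores.
def _demo_apply_rotary():
    q = jt.randn(2, 16, 4, 64)  # [batch, seq, heads, head_dim]
    sincos = fixed_pos_embedding(q, seq_dim=1)  # two [16, 32] tables
    q_rot = apply_rotary_pos_emb(q, sincos)  # same shape as q
    assert q_rot.shape == q.shape
    return q_rot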
def _init_weights(module, config):
    """Initialize module weights with a normal distribution (GPT-J style)."""
    if isinstance(module, (nn.Linear,)):
        # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        jt.init.gauss_(module.weight, mean=0.0, std=config.initializer_range)
        if module.bias is not None:
            jt.init.constant_(module.bias, 0.0)
    elif isinstance(module, nn.Embedding):
        jt.init.gauss_(module.weight, mean=0.0, std=config.initializer_range)
        # Older Jittor Embeddings may lack padding_idx, hence the getattr.
        if getattr(module, "padding_idx", None) is not None:
            module.weight[module.padding_idx] = 0.0
    elif isinstance(module, nn.LayerNorm):
        jt.init.constant_(module.bias, 0.0)
        jt.init.constant_(module.weight, 1.0)
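
# Usage sketch (illustrative; not part of the original module): initialize one
# layer directly. In a full model one would typically loop over its submodules
# and call _init_weights on each. `DemoConfig` is a hypothetical stand-in for
# the model config; only `initializer_range` is read here.
def _demo_init_weights():
    class DemoConfig:
        initializer_range = 0.02  # hypothetical value
    layer = nn.Linear(8, 8)
    _init_weights(layer, DemoConfig())
    return layer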
def _convert_head_mask_to_5d(head_mask, num_hidden_layers, dtype):
    """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
    if head_mask.ndim == 1:
        # The same per-head mask is shared across all layers.
        head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
    elif head_mask.ndim == 2:
        head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # a separate head_mask per layer
    assert head_mask.ndim == 5, f"head_mask.ndim != 5, instead {head_mask.ndim}"
    head_mask = head_mask.astype(dtype)  # cast to the requested dtype for fp16 compatibility
    return head_mask
def get_head_mask(
    head_mask, num_hidden_layers: int,
    is_attention_chunked: bool = False
):
    """Broadcast head_mask to 5D, or return one None per layer when no mask is given."""
    if head_mask is not None:
        head_mask = _convert_head_mask_to_5d(head_mask, num_hidden_layers, 'float16')
        if is_attention_chunked:
            head_mask = head_mask.unsqueeze(-1)
    else:
        head_mask = [None] * num_hidden_layers
    return head_mask
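
# Usage sketch (illustrative; not part of the original module): with no mask we
# get one None per layer; a 1-D per-head mask is broadcast to
# [num_layers, batch, heads, seq, seq].
def _demo_get_head_mask():
    assert get_head_mask(None, num_hidden_layers=12) == [None] * 12
    mask = jt.ones(4)  # 4 attention heads, all kept
    mask5d = get_head_mask(mask, num_hidden_layers=12)
    assert list(mask5d.shape) == [12, 1, 4, 1, 1]
    return mask5d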