Spaces:
Runtime error
Runtime error
# import numpy as np | |
# import torch | |
# import torch.nn as nn | |
# from math import pi | |
# from einops import rearrange, repeat | |
# | |
# ################################################################################# | |
# # Sine/Cosine Positional Embedding Functions # | |
# ################################################################################# | |
# # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py | |
# | |
# def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): | |
# """ | |
# grid_size: int of the grid height and width | |
# return: | |
# pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) | |
# """ | |
# grid_h = np.arange(grid_size, dtype=np.float32) | |
# grid_w = np.arange(grid_size, dtype=np.float32) | |
# grid = np.meshgrid(grid_w, grid_h) # here w goes first | |
# grid = np.stack(grid, axis=0) | |
# | |
# grid = grid.reshape([2, 1, grid_size, grid_size]) | |
# pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) | |
# if cls_token and extra_tokens > 0: | |
# pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) | |
# return pos_embed | |
# | |
# | |
# def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): | |
# assert embed_dim % 2 == 0 | |
# | |
# # use half of dimensions to encode grid_h | |
# emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) | |
# emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) | |
# | |
# emb = np.concatenate([emb_h, emb_w], axis=1) | |
# return emb | |
# | |
# | |
# def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): | |
# """ | |
# embed_dim: output dimension for each position | |
# pos: a list of positions to be encoded: size (M,) | |
# out: (M, D) | |
# """ | |
# assert embed_dim % 2 == 0 | |
# omega = np.arange(embed_dim // 2, dtype=np.float64) | |
# omega /= embed_dim / 2. | |
# omega = 1. / 10000**omega | |
# | |
# pos = pos.reshape(-1) | |
# out = np.einsum('m,d->md', pos, omega) | |
# | |
# emb_sin = np.sin(out) | |
# emb_cos = np.cos(out) | |
# | |
# emb = np.concatenate([emb_sin, emb_cos], axis=1) | |
# return emb | |
# | |
# def broadcat(tensors, dim=-1): | |
# num_tensors = len(tensors) | |
# shape_lens = set(list(map(lambda t: len(t.shape), tensors))) | |
# assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' | |
# shape_len = list(shape_lens)[0] | |
# dim = (dim + shape_len) if dim < 0 else dim | |
# dims = list(zip(*map(lambda t: list(t.shape), tensors))) | |
# expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] | |
# assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' | |
# max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) | |
# expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) | |
# expanded_dims.insert(dim, (dim, dims[dim])) | |
# expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) | |
# tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) | |
# return torch.cat(tensors, dim=dim) | |
# | |
# | |
# def rotate_half(x): | |
# x = rearrange(x, '... (d r) -> ... d r', r=2) | |
# x1, x2 = x.unbind(dim=-1) | |
# x = torch.stack((-x2, x1), dim=-1) | |
# return rearrange(x, '... d r -> ... (d r)') | |
# | |
# ################################################################################# | |
# # VisionRotary # | |
# ################################################################################# | |
# # References: | |
# # EVA: https://github.com/baaivision/EVA | |
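# # Transformer升级之路:2、博采众长的旋转式位置编码 ("The Road to Upgrading Transformers, Part 2: Rotary Position Embedding, Drawing on the Strengths of Many Approaches"): https://spaces.ac.cn/archives/8265 | |
# # Transformer升级之路:4、二维位置的旋转式位置编码 ("The Road to Upgrading Transformers, Part 4: Rotary Position Embedding for 2D Positions"): https://spaces.ac.cn/archives/8397 | |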
# | |
# class VisionRotaryEmbeddingFast(nn.Module): | |
# def __init__( | |
# self, | |
# dim, | |
# pt_hw=(int, int), # (H, W) | |
# ft_hw=None, | |
# custom_freqs = None, | |
# freqs_for = 'lang', | |
# theta = 10000, | |
# max_freq = 10, | |
# num_freqs = 1, | |
# ): | |
# super().__init__() | |
# # Unlike a 1d RoPE, a 2d RoPE requires splitting the dimension into four parts | |
# # References: https://spaces.ac.cn/archives/8397 | |
# | |
# if custom_freqs: | |
# freqs = custom_freqs | |
# elif freqs_for == 'lang': | |
# freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) | |
# elif freqs_for == 'pixel': | |
# freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi | |
# elif freqs_for == 'constant': | |
# freqs = torch.ones(num_freqs).float() | |
# else: | |
# raise ValueError(f'unknown modality {freqs_for}') | |
# | |
# if ft_hw is None: ft_hw = pt_hw | |
# h_t = torch.arange(ft_hw[0]) / ft_hw[0] * pt_hw[0] | |
# w_t = torch.arange(ft_hw[1]) / ft_hw[1] * pt_hw[1] | |
# | |
# h_freqs = torch.einsum('..., f -> ... f', h_t, freqs) | |
# w_freqs = torch.einsum('..., f -> ... f', w_t, freqs) | |
# | |
# h_freqs = repeat(h_freqs, '... n -> ... (n r)', r=2) | |
# w_freqs = repeat(w_freqs, '... n -> ... (n r)', r=2) | |
# | |
# freqs = broadcat((h_freqs[:, None, :], w_freqs[None, :, :]), dim=-1) | |
# freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) | |
# freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) | |
# | |
# self.register_buffer("freqs_cos", freqs_cos) | |
# self.register_buffer("freqs_sin", freqs_sin) | |
# | |
# def forward(self, t): | |
# # 2d RoPE: [[cos(h*theta), -sin(h*theta), 0, 0 ], | |
# # [sin(h*theta), cos(h*theta), 0, 0 ], | |
# # [0, 0, cos(w*theta), -sin(w*theta)], | |
# # [0, 0, sin(w*theta), cos(w*theta) ],] | |
# | |
# return t * self.freqs_cos + rotate_half(t) * self.freqs_sin |