Spaces:
Sleeping
Sleeping
import math | |
import torch | |
from torch import nn | |
class RoPEPositionEncodingSine(nn.Module):
    """
    Two-dimensional rotary position encoding (RoPE) for image feature maps.

    Channels are grouped into adjacent pairs, and each pair is rotated by an angle
    ``position * frequency``. Consecutive pairs alternate between the row (H)
    position and the column (W) position, so half of the rotation pairs encode
    the vertical coordinate and half encode the horizontal one. The sin/cos
    tables are precomputed for a ``max_shape`` grid and sliced to the input's
    spatial size in ``forward``.
    """

    def __init__(self, d_model, max_shape=(256, 256), npe=None, ropefp16=True):
        """
        Args:
            d_model (int): channel dimension C of the inputs. Must be a multiple
                of 4: there are C//4 frequencies per axis, and each angle is
                applied to one pair of channels.
            max_shape (tuple): (H, W) size of the precomputed tables. For a 1/8
                featmap, the max length of 256 corresponds to 2048 pixels.
            npe (sequence): (train_res_H, train_res_W, test_res_H, test_res_W).
                Row/column positions are multiplied by train_res / test_res of
                the corresponding axis (presumably to map test-resolution
                positions onto the range seen in training). Required.
            ropefp16 (bool): store the sin/cos tables as float16 (True) or
                float32 (False). The values are always computed in float32.

        Raises:
            ValueError: if ``npe`` is None or ``d_model`` is not a multiple of 4.
        """
        super().__init__()
        # Explicit checks rather than ``assert`` so they survive ``python -O``.
        if npe is None:
            raise ValueError(
                "npe must be given as (train_res_H, train_res_W, test_res_H, test_res_W)"
            )
        if d_model % 4 != 0:
            raise ValueError(f"d_model must be a multiple of 4, got {d_model}")
        train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3]

        # 1-based row / column index of every grid cell, each [H, W, 1].
        i_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(-1)
        j_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(-1)
        # Rescale positions by the train/test resolution ratio of each axis.
        i_position = i_position * train_res_H / test_res_H
        j_position = j_position * train_res_W / test_res_W

        num_freqs = d_model // 4
        div_term = torch.exp(
            torch.arange(0, num_freqs, 1).float() * (-math.log(10000.0) / num_freqs)
        )
        div_term = div_term[None, None, :]  # [1, 1, C//4]

        # Interleave row and column angles -> [i0, j0, i1, j1, ...]: [H, W, C//2].
        angles = torch.stack(
            (i_position * div_term, j_position * div_term), dim=-1
        ).flatten(start_dim=-2)

        table_dtype = torch.float16 if ropefp16 else torch.float32
        # Duplicate each angle so both channels of a rotation pair share it: [H, W, C].
        sin = angles.sin().repeat_interleave(2, dim=-1).to(table_dtype)
        cos = angles.cos().repeat_interleave(2, dim=-1).to(table_dtype)
        self.register_buffer('sin', sin.unsqueeze(0), persistent=False)  # [1, H, W, C]
        self.register_buffer('cos', cos.unsqueeze(0), persistent=False)  # [1, H, W, C]

    def forward(self, x, ratio=1):
        """
        Apply the rotary encoding to ``x``.

        Args:
            x: [N, H, W, C] with H <= max_shape[0], W <= max_shape[1] and
                C == d_model (the tables are sliced to x's H and W).
            ratio: unused; kept for interface compatibility.

        Returns:
            ``x * cos + rotate_half(x) * sin``, with the same shape as ``x``.
            The dtype follows PyTorch type promotion between ``x`` and the tables.
        """
        height, width = x.size(1), x.size(2)
        cos = self.cos[:, :height, :width, :]
        sin = self.sin[:, :height, :width, :]
        return (x * cos) + (self.rotate_half(x) * sin)

    def rotate_half(self, x):
        """
        Map each adjacent channel pair (a, b) of the last dim to (-b, a).

        Combined with the pair-duplicated sin/cos tables, this gives a 2-D
        rotation of every pair: (a, b) -> (a*cos - b*sin, b*cos + a*sin).
        """
        x = x.unflatten(-1, (-1, 2))
        x1, x2 = x.unbind(dim=-1)
        return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)