Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import re | |
import math | |
from transformers.modeling_utils import ( | |
PreTrainedModel, | |
apply_chunking_to_forward, | |
find_pruneable_heads_and_indices, | |
prune_linear_layer, | |
) | |
from typing import Any, Optional, Tuple, Union | |
from transformers.modeling_outputs import BaseModelOutput | |
from functools import partial | |
from einops import rearrange | |
import numpy as np | |
import random | |
# Copyright (c) Alibaba Cloud. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from collections import OrderedDict | |
import math | |
import requests | |
from io import BytesIO | |
from functools import partial | |
from PIL import Image | |
from typing import Callable, Optional, Sequence, Tuple, List, Union | |
import numpy as np | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torch.nn.init import trunc_normal_ | |
from torchvision import transforms | |
from torchvision.transforms import InterpolationMode | |
def build_matryoshka_query_transformer(config, delay_load=False, **kwargs):
    """Build the matryoshka query transformer (a ``Resampler``) for ``config``.

    The resampler is created with a square grid holding 256 learnable queries
    (a 16 x 16 grid), 16 attention heads and 1024-dimensional keys/values;
    its output width is ``config.hidden_size``.

    Args:
        config: Object exposing a ``hidden_size`` attribute.
        delay_load: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        The newly constructed ``Resampler``.
    """
    max_query_tokens = 256
    # The queries live on a square grid, so the resampler wants the side
    # length rather than the total count.
    side = int(math.sqrt(max_query_tokens))
    return Resampler(
        grid_size=side,
        embed_dim=config.hidden_size,
        num_heads=16,
        kv_dim=1024,
    )
def get_matry_n(num_visual_tokens):
    """Resolve ``num_visual_tokens`` into a concrete visual-token count.

    Args:
        num_visual_tokens: One of

            * ``'first_stage'`` -- always use 256 tokens;
            * ``'second_stage'`` -- draw a random even count from 2 to 256
              (inclusive) with ``random.choice``;
            * an integer, or any value ``int()`` accepts, between 1 and 256.

    Returns:
        The number of tokens to use, an ``int`` between 1 and 256.

    Raises:
        ValueError: If the value is neither a stage name nor convertible to
            an integer between 1 and 256.
    """
    if num_visual_tokens == 'first_stage':
        return 256
    if num_visual_tokens == 'second_stage':
        return random.choice(range(2, 258, 2))

    expected = (
        "num_visual_tokens should be 'first_stage', 'second_stage' "
        "or an integer between 1 and 256"
    )
    # Keep the try body down to the conversion so only conversion failures are
    # translated; the reason travels in the exception instead of stdout.
    try:
        count = int(num_visual_tokens)
    except (ValueError, TypeError) as err:
        raise ValueError(
            f"Invalid input: {num_visual_tokens!r}; {expected}"
        ) from err

    if not 1 <= count <= 256:
        raise ValueError(f"Invalid input: {count}; {expected}")
    return count
def get_abs_pos(abs_pos, tgt_size):
    """Resize a square grid of absolute position embeddings to ``tgt_size``.

    Args:
        abs_pos: Tensor of shape ``(L, C)`` whose ``L`` rows are read as a
            row-major ``sqrt(L) x sqrt(L)`` grid of ``C``-dimensional
            embeddings.
        tgt_size: Target grid size as ``(H, W)``.

    Returns:
        Tensor of shape ``(H * W, C)`` with the same dtype as ``abs_pos``,
        obtained by bicubic interpolation of the source grid.

    Raises:
        ValueError: If ``L`` is not a perfect square. Without this check the
            truncated square root could still produce a reshape that
            "works" and silently scrambles the embeddings.
    """
    num_positions = abs_pos.size(0)
    src_size = int(math.sqrt(num_positions))
    if src_size * src_size != num_positions:
        raise ValueError(
            f"abs_pos must hold a square grid of positions, got {num_positions} rows"
        )

    dtype = abs_pos.dtype
    # (L, C) -> (1, C, S, S): interpolate expects channels-first images.
    # The interpolation runs in float32 and the result is cast back below.
    grid = abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(
        grid,
        size=(tgt_size[0], tgt_size[1]),
        mode="bicubic",
        align_corners=False,
    )
    # (1, C, H, W) -> (H * W, C), restoring the caller's dtype.
    return resized.permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 | |
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a 2D sin-cos position embedding table for a square grid.

    Args:
        embed_dim: Width of each embedding row.
        grid_size: Height and width of the grid (a single int).
        cls_token: When true, prepend one all-zero row for a class token.

    Returns:
        Array of shape ``[grid_size*grid_size, embed_dim]``, or
        ``[1+grid_size*grid_size, embed_dim]`` when ``cls_token`` is set.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # With meshgrid's default indexing the first output varies along the
    # width and the second along the height, i.e. w goes first.
    w_coords, h_coords = np.meshgrid(axis, axis)
    grid = np.stack((w_coords, h_coords), axis=0).reshape(
        [2, 1, grid_size, grid_size]
    )

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return table

    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, table], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a two-channel coordinate grid as sin-cos embeddings.

    Each of the two coordinate planes gets half of the channels: the first
    ``embed_dim // 2`` columns encode ``grid[0]`` and the remaining columns
    encode ``grid[1]``.

    Args:
        embed_dim: Total embedding width; must be even.
        grid: Array indexable as ``grid[0]`` and ``grid[1]``, each holding
            the coordinates of every grid position along one axis.

    Returns:
        Array of shape ``(H*W, embed_dim)``.
    """
    assert embed_dim % 2 == 0

    per_axis_dim = embed_dim // 2
    halves = [
        get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[plane])  # (H*W, D/2)
        for plane in (0, 1)
    ]
    return np.concatenate(halves, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with sines and cosines.

    Args:
        embed_dim: Output width for each position; must be even.
        pos: Array of positions to encode. It is flattened first, so any
            shape with ``M`` elements in total is accepted.

    Returns:
        Array of shape ``(M, embed_dim)``: the sine half in the first
        ``embed_dim // 2`` columns, the cosine half in the rest.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Frequencies 1 / 10000**(i / (D/2)) for i = 0 .. D/2 - 1.
    exponents = np.arange(half_dim, dtype=np.float32) / (embed_dim / 2.)
    inv_freq = 1. / 10000 ** exponents  # (D/2,)

    # Every position times every frequency.
    angles = np.outer(pos.reshape(-1), inv_freq)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
class Resampler(nn.Module):
    """
    A 2D perceiver-resampler network with one cross attention layer, up to
    (grid_size**2) learnable queries and a fixed 2d sincos pos_emb.

    ``forward`` may use only the first ``n`` queries (see ``get_matry_n``),
    which is what makes the number of output tokens adjustable.

    Outputs:
        A tensor with the shape of (batch, n, embed_dim), where ``n`` is the
        number of queries actually used.
    """

    def __init__(
            self,
            grid_size,
            embed_dim,
            num_heads,
            kv_dim=None,
            norm_layer=partial(nn.LayerNorm, eps=1e-6)
    ):
        """Create the queries, the attention layer and the output projection.

        Args:
            grid_size: Side of the square query grid; ``grid_size ** 2``
                queries are created.
            embed_dim: Width of the output after the final projection.
            num_heads: Number of attention heads.
            kv_dim: Width of the queries, keys and values inside the
                attention layer. Defaults to ``embed_dim`` when ``None``.
            norm_layer: Factory called with the width to build each of the
                three normalisation layers.
        """
        super().__init__()
        # Every layer below is sized by kv_dim, so the declared default of
        # None cannot be used as is; fall back to the output width.
        if kv_dim is None:
            kv_dim = embed_dim

        self.num_queries = grid_size ** 2
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Fixed (non-trainable) sin-cos table with one row per query
        # position, stored in half precision.
        self.pos_embed = nn.Parameter(
            torch.from_numpy(get_2d_sincos_pos_embed(kv_dim, grid_size)).half()
        ).requires_grad_(False)

        self.query = nn.Parameter(torch.zeros(self.num_queries, kv_dim))
        trunc_normal_(self.query, std=.02)

        # batch_first is left at its default, so the attention layer works
        # on (sequence, batch, dim) tensors.
        self.attn = nn.MultiheadAttention(kv_dim, num_heads)
        self.ln_q = norm_layer(kv_dim)
        self.ln_k = norm_layer(kv_dim)
        self.ln_v = norm_layer(kv_dim)

        # Maps attention outputs from kv_dim to embed_dim.
        self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(kv_dim, embed_dim))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; used as the callback of ``self.apply``.

        Linear weights get a truncated normal (std 0.02) and zero bias;
        LayerNorm gets zero bias and unit weight. Other modules are left
        untouched.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, num_visual_tokens=256, tgt_size=(24, 24), attn_mask=None):
        """Resample ``x`` into ``n`` tokens of width ``embed_dim``.

        Args:
            x: Features of shape (batch, sequence, kv_dim). Keys are built
                from a ``tgt_size`` position grid and values from ``x``, so
                the sequence length has to equal
                ``tgt_size[0] * tgt_size[1]`` for the attention layer to
                accept them.
            num_visual_tokens: Passed to ``get_matry_n`` to decide how many
                of the leading queries are used.
            tgt_size: (H, W) grid the position table is resized to.
            attn_mask: Forwarded unchanged to the attention layer.

        Returns:
            Tensor of shape (batch, n, embed_dim).
        """
        # The stored table is half precision no matter what dtype the module
        # runs in; cast the resized copy to the feature dtype so keys and
        # values agree. This is a no-op when everything is already half.
        pos_embed = get_abs_pos(self.pos_embed, tgt_size).to(dtype=x.dtype)

        # (batch, sequence, dim) -> (sequence, batch, dim) for the attention.
        x = x.permute(1, 0, 2)
        N = x.shape[1]

        matry_n = get_matry_n(num_visual_tokens)
        q = self._repeat(self.query[:matry_n], N)

        # NOTE(review): the keys are built from the position table alone and
        # carry no content from x -- confirm this is intended before changing.
        out = self.attn(
            self.ln_q(q + self.pos_embed[:matry_n].unsqueeze(1)),
            self.ln_k(self._repeat(pos_embed, N)),
            self.ln_v(x),
            attn_mask=attn_mask)[0]

        # Back to batch-first, then project kv_dim -> embed_dim.
        return out.permute(1, 0, 2) @ self.proj

    def _repeat(self, query, N: int):
        """Copy a (length, dim) tensor N times into (length, N, dim)."""
        return query.unsqueeze(1).repeat(1, N, 1)