# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange, repeat

from .fm_utils import build_1d_sincos_posemb, build_2d_sincos_posemb, pair


class SequenceEncoderEmbedding(nn.Module):
    """Embedding module for encoding sequence inputs, like captions or a sequence of objects.

    Args:
        vocab_size: Vocabulary size
        max_length: Maximum number of tokens in the sequence
        dim_tokens: Dimension of output tokens. Can be set using init method.
        sincos_pos_emb: Set to True (default) to use fixed 1D sin-cos positional embeddings
        max_sincos_pos_emb: Maximum allowed length for sin-cos positional embeddings
        padding_idx: Padding index for word embedding
    """

    def __init__(self,
                 vocab_size: int,
                 max_length: int,
                 dim_tokens: Optional[int] = None,
                 sincos_pos_emb: bool = True,
                 max_sincos_pos_emb: int = 512,
                 padding_idx: int = 0,
                 ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_length = max_length
        self.dim_tokens = dim_tokens
        self.sincos_pos_emb = sincos_pos_emb
        self.padding_idx = padding_idx
        self.max_sincos_pos_emb = max_sincos_pos_emb

        if self.dim_tokens is not None:
            self.init(dim_tokens=dim_tokens)
    def init(self, dim_tokens: int = 768, init_std=0.02):
        """
        Initialize parts of embedding module that are dependent on dimension of tokens.
        Should be called when setting up FourM.

        Args:
            dim_tokens: Dimension of tokens
            init_std: Standard deviation of init
        """
        self.dim_tokens = dim_tokens

        # Task embedding identifying from which task a given token comes from
        # Fixed-size positional embeddings. Can be interpolated to different input sizes
        if self.sincos_pos_emb:
            if self.max_length > self.max_sincos_pos_emb:
                raise ValueError(f"Max length ({self.max_length}) is greater than the number of posembs ({self.max_sincos_pos_emb})")
            pos_emb = build_1d_sincos_posemb(max_len=self.max_sincos_pos_emb, embed_dim=self.dim_tokens)[:self.max_length]
            self.register_buffer("pos_emb", pos_emb)  # self.pos_emb is now a buffer for FSDP
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, self.max_length, self.dim_tokens))
            nn.init.normal_(self.pos_emb, std=init_std)

        self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
        nn.init.normal_(self.mod_emb, std=init_std)

        # Token embedding
        self.token_emb = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.dim_tokens,
                                      padding_idx=self.padding_idx)
    def no_weight_decay(self):
        return set()

    def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass through embedding module, transforming sequence of ids to sequence of embeddings.
        Creates corresponding modality and positional embeddings and adds them to the dict.

        Args:
            d (Dict[str, torch.Tensor]): Modality dict with at least the following keys:
                - 'tensor' (torch.Tensor): Input token sequence for each batch. Shape (B, L) where B is the batch size and L is the sequence length.
                - 'input_mask' (torch.Tensor): Mask for valid tokens in the input sequence (set to 0 for valid tokens and 1 otherwise). Shape (B, L).

        Returns:
            Dict[str, torch.Tensor]: Modality dict with added keys:
                - 'x' (torch.Tensor): Embedded token sequence. Shape (B, L, D) where D is the embedding dimension.
                - 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, L, D).
        """
        ids = d['tensor']
        B = ids.shape[0]
        assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'

        # Map to embedding
        x = self.token_emb(ids)

        expanded_pos_emb = repeat(self.pos_emb, "() n d -> b n d", b=B)

        # Input pos encoding
        input_mask = d['input_mask']
        input_pos_id = (~input_mask).int().cumsum(dim=1) - 1
        input_pos_id[input_mask] = 0
        input_pos_emb = torch.gather(expanded_pos_emb, dim=1, index=repeat(input_pos_id, "b n -> b n d", d=expanded_pos_emb.shape[2]))
        input_pos_emb[input_mask] = 0

        x_emb = input_pos_emb + self.mod_emb

        d['x'] = x
        d['emb'] = x_emb
        return d
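

# ---------------------------------------------------------------------------
# Illustrative usage sketch for SequenceEncoderEmbedding (a hypothetical helper,
# not part of the FourM API). The vocabulary size, sequence length, and token
# dimension below are arbitrary example values, not values from any 4M config.
def _example_sequence_encoder_embedding():
    emb_module = SequenceEncoderEmbedding(vocab_size=30_000, max_length=128, dim_tokens=768)
    ids = torch.randint(0, 30_000, (2, 128))            # (B, L) token ids
    input_mask = torch.zeros(2, 128, dtype=torch.bool)  # False/0 = valid token, True/1 = padding
    out = emb_module({'tensor': ids, 'input_mask': input_mask})
    # out['x'] and out['emb'] both have shape (B, L, D) = (2, 128, 768)
    return out['x'], out['emb']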


class ImageTokenEncoderEmbedding(nn.Module):
    """Embedding module for tokenized spatial inputs.

    Args:
        vocab_size: Vocabulary size
        patch_size: Int or tuple of the patch size over the full image size.
        dim_tokens: Dimension of output tokens. Can be set using init method.
        sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
        image_size: Default image size. Used to initialize size of positional embeddings.
    """

    def __init__(self,
                 vocab_size: int,
                 patch_size: Union[int, Tuple[int, int]] = 16,
                 dim_tokens: Optional[int] = None,
                 sincos_pos_emb: bool = True,
                 image_size: Union[int, Tuple[int, int]] = 224,
                 **kwargs):
        super().__init__()
        self.vocab_size = vocab_size
        self.patch_size = pair(patch_size)
        self.dim_tokens = dim_tokens
        self.sincos_pos_emb = sincos_pos_emb
        self.image_size = pair(image_size)
        self.num_patches = (self.image_size[0] // self.patch_size[0]) * (self.image_size[1] // self.patch_size[1])

        if self.dim_tokens is not None:
            self.init(dim_tokens=dim_tokens)
    def init(self, dim_tokens: int = 768, init_std=0.02):
        """
        Initialize parts of module that are dependent on dimension of tokens.
        Should be called when setting up FourM.

        Args:
            dim_tokens: Dimension of tokens
            init_std: Standard deviation of init
        """
        self.dim_tokens = dim_tokens

        # Task embedding identifying from which task a given token comes from
        # Fixed-size positional embeddings. Can be interpolated to different input sizes
        h_posemb = self.image_size[0] // self.patch_size[0]
        w_posemb = self.image_size[1] // self.patch_size[1]
        if self.sincos_pos_emb:
            pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
            self.register_buffer("pos_emb", pos_emb)  # self.pos_emb is now a buffer for FSDP
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, (h_posemb * w_posemb), self.dim_tokens))
            nn.init.normal_(self.pos_emb, std=init_std)

        self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
        nn.init.normal_(self.mod_emb, std=init_std)

        # Token embedding
        self.token_emb = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.dim_tokens)

    def no_weight_decay(self):
        return set()
    def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass through embedding module, transforming image tokens to a sequence of embeddings.
        Creates corresponding modality and positional embeddings and adds them to the dict.

        Args:
            d (Dict[str, torch.Tensor]): Modality dict with at least the following key:
                - 'tensor' (torch.Tensor): Input image tokens for each batch. Shape (B, H, W) where B is the batch size, and H, W are height and width of the tokenized image.
                - 'input_mask' (torch.Tensor): Mask for valid tokens in the input sequence (set to 0 for valid tokens and 1 otherwise). Shape (B, L).

        Returns:
            Dict[str, torch.Tensor]: Modality dict with added keys:
                - 'x' (torch.Tensor): Embedded token sequence. Shape (B, H*W, D).
                - 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, H*W, D).
        """
        ids = d['tensor']
        B = ids.shape[0]
        ids = ids.reshape(B, -1)

        # Map to embedding
        x = self.token_emb(ids)

        # Create positional embedding + modality embedding
        x_emb = repeat(self.pos_emb + self.mod_emb, '() n d -> b n d', b=B)

        d['x'] = x
        d['emb'] = x_emb
        return d
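

# ---------------------------------------------------------------------------
# Illustrative usage sketch for ImageTokenEncoderEmbedding (a hypothetical helper,
# not part of the FourM API). A 224x224 image tokenized with patch size 16 gives
# a 14x14 grid of token ids; the vocabulary size here is an arbitrary example value.
def _example_image_token_encoder_embedding():
    emb_module = ImageTokenEncoderEmbedding(vocab_size=8192, patch_size=16, dim_tokens=768, image_size=224)
    ids = torch.randint(0, 8192, (2, 14, 14))  # (B, H, W) token ids from a tokenizer
    out = emb_module({'tensor': ids})
    # out['x'] and out['emb'] both have shape (B, H*W, D) = (2, 196, 768)
    return out['x'], out['emb']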


class ImageEncoderEmbedding(nn.Module):
    """Embedding module for spatial inputs, like images or feature maps.
    Creates tokens from patches over the image.

    This adapter / embedding differs from the one of MultiMAE by taking as input a dict and
    separating positional embeddings and modality embeddings from the input projection.
    Input projection is 'x', posemb + modemb is 'emb'.

    Args:
        num_channels: Number of input channels of the image/feature map
        patch_size: Int or tuple of the patch size over the full image size.
        dim_tokens: Dimension of output tokens. Can be set using init method.
        sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
        image_size: Default image size. Used to initialize size of positional embeddings.
    """

    def __init__(self,
                 num_channels: int,
                 patch_size: Union[int, Tuple[int, int]],
                 dim_tokens: Optional[int] = None,
                 sincos_pos_emb: bool = True,
                 image_size: Union[int, Tuple[int, int]] = 224):
        super().__init__()
        self.num_channels = num_channels
        self.patch_size = pair(patch_size)
        self.dim_tokens = dim_tokens
        self.sincos_pos_emb = sincos_pos_emb
        self.image_size = pair(image_size)
        self.num_patches = (self.image_size[0] // self.patch_size[0]) * (self.image_size[1] // self.patch_size[1])

        if self.dim_tokens is not None:
            self.init(dim_tokens=dim_tokens)
    def init(self, dim_tokens: int = 768, init_std=0.02):
        """
        Initialize parts of encoder that are dependent on dimension of tokens.
        Should be called when setting up FourM.

        Args:
            dim_tokens: Dimension of tokens
            init_std: Standard deviation of init
        """
        self.dim_tokens = dim_tokens

        # Task embedding identifying from which task a given token comes from
        # Fixed-size positional embeddings. Can be interpolated to different input sizes
        h_posemb = self.image_size[0] // self.patch_size[0]
        w_posemb = self.image_size[1] // self.patch_size[1]
        if self.sincos_pos_emb:
            pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
            self.register_buffer("pos_emb", pos_emb)  # self.pos_emb is now a buffer for FSDP
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, (h_posemb * w_posemb), self.dim_tokens))
            nn.init.normal_(self.pos_emb, std=init_std)

        self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
        nn.init.normal_(self.mod_emb, std=init_std)

        # Image -> tokens projection
        # No bias term here, so modality embedding fully comes from self.mod_emb
        self.proj = nn.Linear(self.num_channels * self.patch_size[0] * self.patch_size[1], self.dim_tokens, bias=False)

    def no_weight_decay(self):
        return set()
    def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass through embedding module, transforming image to sequence of tokens.
        Creates corresponding modality and positional embeddings and adds them to the dict.

        Args:
            d (Dict[str, torch.Tensor]): Modality dict with at least the following key:
                - 'tensor' (torch.Tensor): Input image for each batch. Shape (B, C, H, W) where B is the batch size, C is the number of channels, and H, W are height and width of the image.

        Returns:
            Dict[str, torch.Tensor]: Modality dict with added keys:
                - 'x' (torch.Tensor): Embedded token sequence. Shape (B, (H / PH) * (W / PW), D), where PH and PW are the patch sizes.
                - 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, (H / PH) * (W / PW), D).
        """
        x = d['tensor']
        B, C, H, W = x.shape
        assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
        assert (H % self.patch_size[0] == 0) and (W % self.patch_size[1] == 0), f'Image sizes {H}x{W} must be divisible by patch sizes {self.patch_size[0]}x{self.patch_size[1]}'

        # Create patches [B, C, H, W] -> [B, (H/PH)*(W/PW), PH*PW*C], then project to [B, (H/PH)*(W/PW), D]
        x_patch = self.proj(rearrange(x, 'b d (nh ph) (nw pw) -> b (nh nw) (ph pw d)', ph=self.patch_size[0], pw=self.patch_size[1]))

        # Create positional embedding + modality embedding
        x_emb = repeat(self.pos_emb + self.mod_emb, '() n d -> b n d', b=B)

        d['x'] = x_patch
        d['emb'] = x_emb
        return d
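

# ---------------------------------------------------------------------------
# Illustrative usage sketch for ImageEncoderEmbedding (a hypothetical helper,
# not part of the FourM API), assuming RGB input at 224x224 with 16x16 patches.
def _example_image_encoder_embedding():
    emb_module = ImageEncoderEmbedding(num_channels=3, patch_size=16, dim_tokens=768, image_size=224)
    rgb = torch.randn(2, 3, 224, 224)  # (B, C, H, W) image batch
    out = emb_module({'tensor': rgb})
    # out['x'] and out['emb'] both have shape (B, (H/PH)*(W/PW), D) = (2, 196, 768)
    return out['x'], out['emb']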


class SequenceEmbEncoderEmbedding(nn.Module):
    """Adapter for sequence emb inputs, like T5-XXL, CLIP text embeddings.

    Args:
        max_length: Maximum number of tokens in the sequence
        dim_tokens: Dimension of output tokens. Can be set using init method.
        sincos_pos_emb: Set to True (default) to use fixed 1D sin-cos positional embeddings
        max_sincos_pos_emb: Maximum allowed length for sin-cos positional embeddings
        padding_idx: Padding index for word embedding
        orig_emb_dim: Dimension of original embeddings
        bottleneck_dim: Dimension of bottleneck layer
        use_bottleneck: Set to True to use bottleneck layer
    """

    def __init__(self,
                 max_length: int,
                 dim_tokens: Optional[int] = None,
                 sincos_pos_emb: bool = True,
                 max_sincos_pos_emb: int = 512,
                 padding_idx: int = 0,
                 orig_emb_dim: int = 4096,
                 bottleneck_dim: int = 64,
                 use_bottleneck: bool = False,
                 ):
        super().__init__()
        self.max_length = max_length
        self.dim_tokens = dim_tokens
        self.sincos_pos_emb = sincos_pos_emb
        self.padding_idx = padding_idx
        self.max_sincos_pos_emb = max_sincos_pos_emb
        self.orig_emb_dim = orig_emb_dim
        self.use_bottleneck = use_bottleneck
        if self.use_bottleneck:
            self.bottleneck_dim = bottleneck_dim

        if self.dim_tokens is not None:
            self.init(dim_tokens=dim_tokens)
    def init(self, dim_tokens: int = 768, init_std=0.02):
        """
        Initialize parts of embedding module that are dependent on dimension of tokens.
        Should be called when setting up FourM.

        Args:
            dim_tokens: Dimension of tokens
            init_std: Standard deviation of init
        """
        self.dim_tokens = dim_tokens

        # Task embedding identifying from which task a given token comes from
        # Fixed-size positional embeddings. Can be interpolated to different input sizes
        if self.sincos_pos_emb:
            if self.max_length > self.max_sincos_pos_emb:
                raise ValueError(f"Max length ({self.max_length}) is greater than the number of posembs ({self.max_sincos_pos_emb})")
            pos_emb = build_1d_sincos_posemb(max_len=self.max_sincos_pos_emb, embed_dim=self.dim_tokens)[:self.max_length]
            self.register_buffer("pos_emb", pos_emb)  # self.pos_emb is now a buffer for FSDP
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, self.max_length, self.dim_tokens))
            nn.init.normal_(self.pos_emb, std=init_std)

        self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
        nn.init.normal_(self.mod_emb, std=init_std)

        # Token embedding projection
        if self.use_bottleneck:
            self.emb_proj = nn.Sequential(
                nn.Linear(self.orig_emb_dim, self.bottleneck_dim),
                nn.Linear(self.bottleneck_dim, self.dim_tokens),
            )
        else:
            self.emb_proj = nn.Linear(self.orig_emb_dim, self.dim_tokens)
    def no_weight_decay(self):
        return set()

    def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass through embedding module, projecting original embeddings to the Transformer dimension.
        Creates corresponding modality and positional embeddings and adds them to the dict.

        Args:
            d (Dict[str, torch.Tensor]): Modality dict with at least the following keys:
                - 'tensor' (torch.Tensor): Input token sequence for each batch. Shape (B, L, E) where B is the batch size, L is the sequence length, and E is the dimension of the original embeddings.
                - 'input_mask' (torch.Tensor): Mask for valid tokens in the input sequence (set to 0 for valid tokens and 1 otherwise). Shape (B, L).

        Returns:
            Dict[str, torch.Tensor]: Modality dict with added keys:
                - 'x' (torch.Tensor): Embedded token sequence. Shape (B, L, D) where D is the Transformer embedding dimension.
                - 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, L, D).
        """
        orig_emb = d['tensor']
        B = orig_emb.shape[0]
        assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'

        # Map to embedding
        x = self.emb_proj(orig_emb)

        expanded_pos_emb = repeat(self.pos_emb, "() n d -> b n d", b=B)

        # Input pos encoding
        input_mask = d['input_mask']
        input_pos_id = (~input_mask).int().cumsum(dim=1) - 1
        input_pos_id[input_mask] = 0
        input_pos_emb = torch.gather(expanded_pos_emb, dim=1, index=repeat(input_pos_id, "b n -> b n d", d=expanded_pos_emb.shape[2]))
        input_pos_emb[input_mask] = 0

        x_emb = input_pos_emb + self.mod_emb

        d['x'] = x
        d['emb'] = x_emb
        return d
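

# ---------------------------------------------------------------------------
# Illustrative usage sketch for SequenceEmbEncoderEmbedding (a hypothetical helper,
# not part of the FourM API), assuming precomputed 4096-dim text embeddings
# (e.g. from T5-XXL) with a maximum caption length of 77 tokens.
def _example_sequence_emb_encoder_embedding():
    emb_module = SequenceEmbEncoderEmbedding(max_length=77, dim_tokens=768, orig_emb_dim=4096)
    text_emb = torch.randn(2, 77, 4096)                # (B, L, E) original embeddings
    input_mask = torch.zeros(2, 77, dtype=torch.bool)  # False/0 = valid token, True/1 = padding
    out = emb_module({'tensor': text_emb, 'input_mask': input_mask})
    # out['x'] and out['emb'] both have shape (B, L, D) = (2, 77, 768)
    return out['x'], out['emb']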