# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from torch import nn
from einops import rearrange


class BottleneckBlock(nn.Module):
    def __init__(self, thin, wide):
        super(BottleneckBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Linear(thin, wide),
            nn.GELU(),
            nn.Linear(wide, thin),
        )

    def forward(self, x):
        out = self.block(x)
        return out
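

# Usage sketch (illustrative, not part of the original file): a BottleneckBlock
# expands from `thin` to `wide` and projects back, so its output width matches
# its input width and it can be used inside a residual connection, e.g.:
#
#   import torch
#   block = BottleneckBlock(thin=256, wide=1024)
#   x = torch.randn(8, 196, 256)
#   assert block(x).shape == x.shape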


class StandardMLP(nn.Module):
    def __init__(self, dim_in, dim_out, widths):
        super(StandardMLP, self).__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.widths = widths
        self.linear_in = nn.Linear(self.dim_in, self.widths[0])
        self.linear_out = nn.Linear(self.widths[-1], self.dim_out)
        layers = []
        layer_norms = []
        for i in range(len(self.widths) - 1):
            layers.append(nn.Linear(self.widths[i], self.widths[i + 1]))
            layer_norms.append(nn.LayerNorm(self.widths[i + 1]))
        # Register as ModuleLists so the parameters are tracked by the module
        self.layers = nn.ModuleList(layers)
        self.layer_norms = nn.ModuleList(layer_norms)

    def forward(self, x):
        # If x is an image, apply the MLP point-wise to each token/pixel
        if x.ndim == 4:
            _, _, H, W = x.shape
            x = rearrange(x, 'b d h w -> b (h w) d')
            x_is_image = True
        else:
            x_is_image = False

        z = self.linear_in(x)
        for layer, norm in zip(self.layers, self.layer_norms):
            z = norm(z)
            z = layer(z)
        out = self.linear_out(z)

        # If x was an image, rearrange back to image shape
        if x_is_image:
            out = rearrange(out, 'b (h w) d -> b d h w', h=H, w=W)
        return out
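

# Usage sketch (illustrative): StandardMLP(dim_in, dim_out, widths) applies
# linear_in, then one LayerNorm + Linear pair per consecutive pair of widths,
# then linear_out, e.g.:
#
#   import torch
#   mlp = StandardMLP(dim_in=32, dim_out=10, widths=[64, 64])
#   x = torch.randn(8, 16, 32)
#   assert mlp(x).shape == (8, 16, 10)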


class BottleneckMLP(nn.Module):
    def __init__(self, dim_in, dim_out, block_dims):
        super(BottleneckMLP, self).__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.block_dims = block_dims
        self.linear_in = nn.Linear(self.dim_in, self.block_dims[0][1])
        self.linear_out = nn.Linear(self.block_dims[-1][1], self.dim_out)
        blocks = []
        layernorms = []
        for block_dim in self.block_dims:
            wide, thin = block_dim
            blocks.append(BottleneckBlock(thin=thin, wide=wide))
            layernorms.append(nn.LayerNorm(thin))
        self.blocks = nn.ModuleList(blocks)
        self.layernorms = nn.ModuleList(layernorms)

    def forward(self, x):
        # If x is an image, apply the MLP point-wise to each token/pixel
        if x.ndim == 4:
            _, _, H, W = x.shape
            x = rearrange(x, 'b d h w -> b (h w) d')
            x_is_image = True
        else:
            x_is_image = False

        x = self.linear_in(x)
        for block, norm in zip(self.blocks, self.layernorms):
            x = x + block(norm(x))
        out = self.linear_out(x)

        # If x was an image, rearrange back to image shape
        if x_is_image:
            out = rearrange(out, 'b (h w) d -> b d h w', h=H, w=W)
        return out
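

# Usage sketch (illustrative, assumed shapes): both MLPs accept token tensors
# of shape (B, N, D) or image-like tensors of shape (B, D, H, W); in the
# latter case the MLP is applied point-wise per pixel and reshaped back, e.g.:
#
#   import torch
#   mlp = BottleneckMLP(dim_in=3, dim_out=16, block_dims=[[1024, 256]] * 2)
#   img = torch.randn(4, 3, 14, 14)
#   assert mlp(img).shape == (4, 16, 14, 14)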


def build_mlp(model_id: str = "BottleneckMLP/B_6-Wi_1024",
              dim_in: Optional[int] = None,
              dim_out: Optional[int] = None,
              **kwargs) -> nn.Module:
    """Constructs an MLP model from a model ID string, see
    "Scaling MLPs: A Tale of Inductive Bias" (https://arxiv.org/abs/2306.13575).

    Args:
        model_id: Model ID string, e.g. "BottleneckMLP/B_6-Wi_1024".
            See https://arxiv.org/abs/2306.13575 for options and details.
        dim_in: Input dimensionality. If None, defaults to the MLP width.
        dim_out: Output dimensionality. If None, defaults to the MLP width.

    Returns:
        MLP model.
    """
    model, architecture = model_id.split("/")
    assert model in ["BottleneckMLP", "MLP"], f"Model {model} not supported."

    sep = architecture.split("-")
    num_blocks = int(sep[0].split("_")[1])
    thin = int(sep[1].split("_")[1])

    # If dim_in and dim_out are not specified, use the MLP width
    dim_in = dim_in or thin
    dim_out = dim_out or thin

    if len(sep) == 3:
        expansion_factor = int(sep[2].split("_")[1])
    else:
        expansion_factor = 4

    if model == "BottleneckMLP":
        blocks = [[expansion_factor * thin, thin] for _ in range(num_blocks)]
        return BottleneckMLP(
            dim_in=dim_in,
            dim_out=dim_out,
            block_dims=blocks,
        )
    elif model == "MLP":
        widths = [thin for _ in range(num_blocks)]
        return StandardMLP(
            dim_in=dim_in,
            dim_out=dim_out,
            widths=widths,
        )
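

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original file).
    # In the model ID, "B_2" sets the number of blocks, "Wi_64" the thin
    # width, and an optional "Ex_<k>" suffix the expansion factor (default 4).
    import torch

    mlp = build_mlp("BottleneckMLP/B_2-Wi_64", dim_in=32, dim_out=10)
    tokens = torch.randn(8, 16, 32)
    print(mlp(tokens).shape)  # expected: torch.Size([8, 16, 10])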