# Apollo-LMMs-Apollo-3B-t32 / mm_connector.py
# Hugging Face file-page header (not part of the module): uploader hawky-ai-labs,
# commit message "Upload folder using huggingface_hub", commit 7441f42 (verified).
import re, math, torch
from collections import OrderedDict
from typing import Optional, Tuple
from torch import nn
from torch.nn.init import trunc_normal_, normal_
import torch.utils.checkpoint
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
class ClassInstantier(OrderedDict):
    """Ordered mapping whose lookups return a freshly constructed object.

    Each stored value is either a callable (typically a class) or a
    ``(callable, kwargs)`` tuple. Indexing calls the callable -- with the
    keyword arguments when a tuple was stored -- and returns the result, so
    every lookup produces a new object.
    """

    def __getitem__(self, key):
        entry = super().__getitem__(key)
        if isinstance(entry, tuple):
            factory, factory_kwargs = entry
        else:
            factory, factory_kwargs = entry, {}
        return factory(**factory_kwargs)
# Maps an activation name to the module class that implements it.
ACT2CLS = {"silu": nn.SiLU}
# Indexing ACT2FN by name returns a new instance of the mapped activation class.
ACT2FN = ClassInstantier(ACT2CLS)
class WeightedNorm(nn.Module):
    """LayerNorm followed by a learned element-wise scale."""

    def __init__(self, hidden_size):
        """
        Args:
            hidden_size: Size of the last dimension that is normalised and scaled.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.norm = nn.LayerNorm(hidden_size)
        # The scale starts as ones and is then re-drawn around 1 with a small spread.
        scale = nn.Parameter(torch.ones(hidden_size))
        normal_(scale, mean=1, std=.02)
        # NOTE: the attribute name `wheight` (sic) is the registered parameter name, so it is kept as-is.
        self.wheight = scale

    def forward(self, x):
        """Normalise `x` over its last dimension, then multiply by the learned scale."""
        normalised = self.norm(x)
        return normalised * self.wheight
class PerceiverMLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        output_size: int,
        hidden_act: str,
    ):
        """
        Args:
            hidden_size: Input feature size.
            intermediate_size: Width of the gate and up projections.
            output_size: Feature size produced by the down projection.
            hidden_act: Key looked up in `ACT2FN` to obtain the activation module.
        """
        super().__init__()
        # All three projections are bias-free.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        """Apply the gated MLP to `x` and return the down-projected result."""
        gate = self.act_fn(self.gate_proj(x))
        projected = self.up_proj(x)
        return self.down_proj(gate * projected)
# Adapted from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Produces the same result as torch.repeat_interleave(x, dim=1, repeats=n_rep): an input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep` is 1 the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
class PerceiverAttention(nn.Module):
    """Eager multi-head cross-attention used by the Perceiver resampler.

    Queries are projected from `latents`; keys and values are projected from `context` and `latents`
    concatenated along the sequence dimension.
    """

    def __init__(self, connector_config, layer_idx: Optional[int] = None) -> None:
        """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`

        Args:
            connector_config: Object exposing `text_hidden_size`, `resampler_n_heads`, `resampler_head_dim`
                and `num_key_value_heads`.
            layer_idx (`int`, *optional*): Index of this layer; handed to `past_key_value.update` when a cache
                object is supplied.

        Raises:
            ValueError: If `num_key_value_heads` is not positive or does not evenly divide `resampler_n_heads`.
        """
        super().__init__()

        # Keep the caller-supplied index; it was previously discarded and always stored as None.
        self.layer_idx = layer_idx
        self.hidden_size = connector_config.text_hidden_size
        self.num_heads = connector_config.resampler_n_heads
        self.head_dim = connector_config.resampler_head_dim
        self.num_key_value_heads = connector_config.num_key_value_heads

        # Fail early: with a non-divisible head count the floor division below would silently produce a
        # group count whose repeated key/value heads no longer match the number of query heads.
        if self.num_key_value_heads <= 0 or self.num_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"`resampler_n_heads` ({self.num_heads}) must be divisible by a positive"
                f" `num_key_value_heads` ({self.num_key_value_heads})"
            )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.is_causal = False

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!

        Args:
            latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
            context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
            past_key_value (*optional*): Cache object; when not `None`, its `update(key, value, layer_idx)` method
                is called with the freshly projected key/value states.
            output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
            use_cache (`bool`, *optional*, defaults to `False`): Accepted for interface compatibility; not read here.

        Returns:
            Tuple `(attn_output, attn_weights, past_key_value)`. `attn_output` has shape
            [bsz, n_latents, hidden_size]; `attn_weights` has shape [bsz, num_heads, n_latents, seq + n_latents]
            when `output_attentions` is true and is `None` otherwise.

        Raises:
            ValueError: If the attention weights or the attention output do not have the expected shape.
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        # Keys and values see the context followed by the latents themselves.
        hidden_states = torch.concat([context, latents], dim=-2)

        query_states = self.q_proj(latents)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # [bsz, seq, heads * head_dim] -> [bsz, heads, seq, head_dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # A `past_key_value` attribute set on the module, if any, takes precedence over the argument.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # [bsz, heads, q_len, head_dim] -> [bsz, q_len, heads * head_dim]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
# Maps an `_attn_implementation` name to the attention class that PerceiverLayer instantiates.
# Only the "eager" implementation is registered.
PERCEIVER_ATTENTION_CLASSES = {
"eager": PerceiverAttention,
}
class PerceiverLayer(nn.Module):
    """One resampler block: attention of the latents over (context, latents), then a gated MLP.

    Both sub-blocks are wrapped in a residual connection on the latents.
    """

    def __init__(self, connector_config, layer_idx: int):
        """
        Args:
            connector_config: Object exposing `text_hidden_size`, `num_output_tokens`, `resampler_depth`,
                `ff_multi`, `hidden_act`, `_attn_implementation` and the fields read by the attention class.
            layer_idx (`int`): Index of this layer, forwarded to the attention class.
        """
        super().__init__()
        self.hidden_size = connector_config.text_hidden_size
        self.n_latents = connector_config.num_output_tokens
        self.depth = connector_config.resampler_depth
        self.ff_multi = connector_config.ff_multi

        # Latents and context are normalised by separate norms before attention.
        self.input_latents_norm = WeightedNorm(self.hidden_size)
        self.input_context_norm = WeightedNorm(self.hidden_size)
        self.self_attn = PERCEIVER_ATTENTION_CLASSES[connector_config._attn_implementation](
            connector_config, layer_idx=layer_idx
        )
        self.post_attention_layernorm = WeightedNorm(self.hidden_size)
        self.mlp = PerceiverMLP(
            hidden_size=connector_config.text_hidden_size,
            intermediate_size=connector_config.text_hidden_size * self.ff_multi,
            output_size=connector_config.text_hidden_size,
            hidden_act=connector_config.hidden_act,
        )

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): accepted for interface compatibility;
                it is not forwarded to the attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights of this layer's attention module.
            use_cache (`bool`, *optional*):
                If truthy, the key/value object returned by the attention module is appended to the outputs.

        Returns:
            Tuple whose first element is the updated latents, followed by the attention weights when
            `output_attentions` is truthy and by the attention module's key/value return when `use_cache` is truthy.
        """
        residual = latents

        latents = self.input_latents_norm(latents)
        context = self.input_context_norm(context)

        # Forward `output_attentions`: without it the attention module always returns `None` weights,
        # so callers asking for attentions would receive `None`.
        latents, self_attn_weights, present_key_value = self.self_attn(
            latents=latents,
            context=context,
            output_attentions=output_attentions,
        )

        latents = residual + latents
        residual = latents

        latents = self.post_attention_layernorm(latents)
        latents = self.mlp(latents)
        latents = residual + latents

        outputs = (latents,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
class PerceiverResampler(nn.Module):
    """Perceiver Resampler: compresses input embeddings into `num_output_tokens` learned latents."""

    def __init__(self, connector_config) -> None:
        """
        Args:
            connector_config: Object exposing `text_hidden_size`, `hidden_act`, `num_output_tokens`,
                `resampler_depth`, `_attn_implementation` and the fields read by `PerceiverLayer`.
        """
        super().__init__()
        self.hidden_size = connector_config.text_hidden_size
        self.hidden_act = connector_config.hidden_act
        self.n_latents = connector_config.num_output_tokens
        self.depth = connector_config.resampler_depth

        # Learned latents shared by every batch element; created as zeros here.
        self.latents = nn.Parameter(torch.zeros(self.n_latents, self.hidden_size))

        # Stack of `depth` resampler blocks followed by a final norm.
        layer_stack = [PerceiverLayer(connector_config, idx) for idx in range(self.depth)]
        self.layers = nn.ModuleList(layer_stack)
        self.norm = WeightedNorm(self.hidden_size)

        self._use_flash_attention_2 = connector_config._attn_implementation == "flash_attention_2"

    def forward(
        self,
        context: torch.Tensor,
        attention_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Resample `context` into the fixed set of latents.

        Args:
            context: Tensor whose first dimension is the batch size; passed unchanged to every layer.
            attention_mask: Accepted but not used by this implementation.

        Returns:
            The latents after all layers and the final norm, with a leading batch dimension.
        """
        # (n_latents, embed) -> (bsz, n_latents, embed), broadcast without copying
        batch_size = context.shape[0]
        compressed_context = self.latents.unsqueeze(0).expand(batch_size, *self.latents.shape)

        for perceiver_layer in self.layers:
            compressed_context = perceiver_layer(
                compressed_context,
                context,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
            )[0]

        return self.norm(compressed_context)
def build_mm_projector(
    input_dim,
    output_dim,
    projector_type,
    hidden_act='silu',
    delay_load=False,
    token_input_shape=0,
    **kwargs
) -> nn.Sequential:
    """Build the projector that maps `input_dim` features to `output_dim` features.

    The projector always starts with ``Linear(input_dim, output_dim)``. When `projector_type` ends with
    ``mlp<N>x_gelu``, ``N - 1`` further ``GELU -> Linear(output_dim, output_dim)`` pairs are appended;
    any other `projector_type` yields the single linear layer.

    `hidden_act`, `delay_load`, `token_input_shape` and extra keyword arguments are accepted but not used.
    """
    layers = [nn.Linear(input_dim, output_dim)]

    depth_match = re.match(r'.*mlp(\d+)x_gelu$', projector_type)
    # The first linear layer already counts towards the requested depth.
    extra_blocks = 0 if depth_match is None else int(depth_match.group(1)) - 1
    for _ in range(extra_blocks):
        layers += [nn.GELU(), nn.Linear(output_dim, output_dim)]

    return nn.Sequential(*layers)
class MMConnector(PreTrainedModel):
    """Connector that projects vision features to the text hidden size and then resamples them.

    The input first goes through the projector built by `build_mm_projector`, and the result is passed
    to a `PerceiverResampler`.
    """

    config_class = PretrainedConfig

    def __init__(self, config: PretrainedConfig) -> None:
        """
        Args:
            config: Configuration exposing `vision_hidden_size`, `text_hidden_size`, `projector_type`,
                `token_input_shape` and the fields read by `PerceiverResampler`.
        """
        super().__init__(config)
        self.proj = build_mm_projector(
            input_dim=config.vision_hidden_size,
            output_dim=config.text_hidden_size,
            projector_type=config.projector_type,
            token_input_shape=config.token_input_shape,
        )
        self.resampler = PerceiverResampler(config)

    def forward(self, x):
        """Project `x`, then return the resampler's output for the projected features."""
        projected = self.proj(x)
        return self.resampler(projected)