# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union

import torch
import torch.version
from torch import nn
from transformers import (
    AutoConfig,
    GemmaForCausalLM,
    PaliGemmaForConditionalGeneration,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.cache_utils import Cache
from transformers.models.auto import CONFIG_MAPPING

from lerobot.common.policies.pi0.flex_attention import flex_attention_forward


def apply_rope(x, positions, max_wavelength=10_000):
    """
    Applies RoPE positions [B, L] to x [B, L, H, D].
    """
    d_half = x.shape[-1] // 2
    device = x.device
    dtype = x.dtype
    x = x.to(torch.float32)

    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)

    radians = radians[..., None, :]

    sin = torch.sin(radians)  # .to(dtype=dtype)
    cos = torch.cos(radians)  # .to(dtype=dtype)

    x1, x2 = x.split(d_half, dim=-1)
    res = torch.empty_like(x)
    res[..., :d_half] = x1 * cos - x2 * sin
    res[..., d_half:] = x2 * cos + x1 * sin

    return res.to(dtype)
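

# Illustrative usage sketch (not part of the original file): `apply_rope` is shape- and
# dtype-preserving, taking positions of shape [B, L] and states of shape [B, L, H, D].
# The tensor sizes below are arbitrary example values.
def _example_apply_rope():  # pragma: no cover - illustration only
    batch_size, seq_len, num_heads, head_dim = 2, 6, 8, 256
    x = torch.randn(batch_size, seq_len, num_heads, head_dim)
    positions = torch.arange(seq_len).expand(batch_size, seq_len)
    out = apply_rope(x, positions)
    assert out.shape == x.shape
    assert out.dtype == x.dtype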


class PaliGemmaWithExpertConfig(PretrainedConfig):
    model_type = "PaliGemmaWithExpertModel"
    sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}

    def __init__(
        self,
        paligemma_config: dict | None = None,
        gemma_expert_config: dict | None = None,
        freeze_vision_encoder: bool = True,
        train_expert_only: bool = True,
        attention_implementation: str = "eager",
        **kwargs,
    ):
        self.freeze_vision_encoder = freeze_vision_encoder
        self.train_expert_only = train_expert_only
        self.attention_implementation = attention_implementation

        if paligemma_config is None:
            # Default config from Pi0
            self.paligemma_config = CONFIG_MAPPING["paligemma"](
                transformers_version="4.48.1",
                _vocab_size=257152,
                bos_token_id=2,
                eos_token_id=1,
                hidden_size=2048,
                image_token_index=257152,
                model_type="paligemma",
                pad_token_id=0,
                projection_dim=2048,
                text_config={
                    "hidden_activation": "gelu_pytorch_tanh",
                    "hidden_size": 2048,
                    "intermediate_size": 16384,
                    "model_type": "gemma",
                    "num_attention_heads": 8,
                    "num_hidden_layers": 18,
                    "num_image_tokens": 256,
                    "num_key_value_heads": 1,
                    "torch_dtype": "float32",
                    "vocab_size": 257152,
                },
                vision_config={
                    "hidden_size": 1152,
                    "intermediate_size": 4304,
                    "model_type": "siglip_vision_model",
                    "num_attention_heads": 16,
                    "num_hidden_layers": 27,
                    "num_image_tokens": 256,
                    "patch_size": 14,
                    "projection_dim": 2048,
                    "projector_hidden_act": "gelu_fast",
                    "torch_dtype": "float32",
                    "vision_use_head": False,
                },
            )
        elif isinstance(paligemma_config, dict):
            # Override Pi0 default config for PaliGemma
            if "model_type" not in paligemma_config:
                paligemma_config["model_type"] = "paligemma"

            cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
            self.paligemma_config = cfg_cls(**paligemma_config)

        if gemma_expert_config is None:
            # Default config from Pi0
            self.gemma_expert_config = CONFIG_MAPPING["gemma"](
                attention_bias=False,
                attention_dropout=0.0,
                bos_token_id=2,
                eos_token_id=1,
                head_dim=256,
                hidden_act="gelu_pytorch_tanh",
                hidden_activation="gelu_pytorch_tanh",
                hidden_size=1024,
                initializer_range=0.02,
                intermediate_size=4096,
                max_position_embeddings=8192,
                model_type="gemma",
                num_attention_heads=8,
                num_hidden_layers=18,
                num_key_value_heads=1,
                pad_token_id=0,
                rms_norm_eps=1e-06,
                rope_theta=10000.0,
                torch_dtype="float32",
                transformers_version="4.48.1",
                use_cache=True,
                vocab_size=257152,
            )
        elif isinstance(gemma_expert_config, dict):
            # Override Pi0 default config for Gemma Expert
            if "model_type" not in gemma_expert_config:
                gemma_expert_config["model_type"] = "gemma"

            cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
            self.gemma_expert_config = cfg_cls(**gemma_expert_config)

        super().__init__(**kwargs)

    def __post_init__(self):
        if self.train_expert_only and not self.freeze_vision_encoder:
            raise ValueError(
                "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
            )

        if self.attention_implementation not in ["eager", "fa2", "flex"]:
            raise ValueError(
                f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
            )
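

# Illustrative sketch (not part of the original file): two hypothetical ways to build the config.
# Passing no arguments uses the Pi0 defaults above; passing a dict overrides a sub-config, with
# `model_type` filled in when missing. The override values below are arbitrary examples.
def _example_build_config():  # pragma: no cover - illustration only
    default_config = PaliGemmaWithExpertConfig()
    small_expert_config = PaliGemmaWithExpertConfig(
        gemma_expert_config={"hidden_size": 512, "intermediate_size": 2048, "num_hidden_layers": 4},
    )
    return default_config, small_expert_config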


class PaliGemmaWithExpertModel(PreTrainedModel):
    config_class = PaliGemmaWithExpertConfig

    def __init__(self, config: PaliGemmaWithExpertConfig):
        super().__init__(config=config)
        self.config = config
        self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
        self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
        # Remove unused embed_tokens
        self.gemma_expert.model.embed_tokens = None

        self.to_bfloat16_like_physical_intelligence()
        self.set_requires_grad()

    def set_requires_grad(self):
        if self.config.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
            for params in self.paligemma.vision_tower.parameters():
                params.requires_grad = False

        if self.config.train_expert_only:
            self.paligemma.eval()
            for params in self.paligemma.parameters():
                params.requires_grad = False

    def train(self, mode: bool = True):
        super().train(mode)

        if self.config.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()

        if self.config.train_expert_only:
            self.paligemma.eval()

    def to_bfloat16_like_physical_intelligence(self):
        self.paligemma = self.paligemma.to(dtype=torch.bfloat16)

        params_to_change_dtype = [
            "language_model.model.layers",
            "gemma_expert.model.layers",
            "vision_tower",
            "multi_modal",
        ]
        for name, param in self.named_parameters():
            if any(selector in name for selector in params_to_change_dtype):
                param.data = param.data.to(dtype=torch.bfloat16)

    def embed_image(self, image: torch.Tensor):
        return self.paligemma.get_image_features(image)

    def embed_language_tokens(self, tokens: torch.Tensor):
        return self.paligemma.language_model.model.embed_tokens(tokens)
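
    # Illustrative note (not part of the original file): with the default Pi0 PaliGemma config
    # above (num_image_tokens=256, projection_dim=2048), `embed_image` is expected to return
    # visual embeddings of shape [B, 256, 2048], and `embed_language_tokens` returns
    # [B, L, hidden_size] token embeddings. Exact shapes depend on the configured vision tower.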

    # TODO: break down this huge forward into modules or functions
    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        inputs_embeds: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        fill_kv_cache: Optional[bool] = None,
    ):
        """Run one joint pass through PaliGemma and the Gemma expert.

        `inputs_embeds` is a list with one entry per sub-model (PaliGemma first, Gemma expert
        second); entries set to None are skipped. The non-None embeddings are concatenated along
        the sequence dimension for attention, then split back so each sub-model applies its own
        projections, MLP and norms. Returns a list of output embeddings with the same layout as
        `inputs_embeds`, plus the key/value cache.
        """
        models = [self.paligemma.language_model.model, self.gemma_expert.model]

        for hidden_states in inputs_embeds:
            # TODO this is very inefficient
            # dtype is always the same, batch size too (if > 1 len)
            # device could be trickier in multi gpu edge cases but that's it
            if hidden_states is None:
                continue
            batch_size = hidden_states.shape[0]

        # RMSNorm
        num_layers = self.paligemma.config.text_config.num_hidden_layers
        head_dim = self.paligemma.config.text_config.head_dim
        for layer_idx in range(num_layers):
            query_states = []
            key_states = []
            value_states = []
            for i, hidden_states in enumerate(inputs_embeds):
                if hidden_states is None:
                    continue
                layer = models[i].layers[layer_idx]
                # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
                # hidden_states = hidden_states * normalizer
                hidden_states = layer.input_layernorm(hidden_states)

                input_shape = hidden_states.shape[:-1]
                hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

                hidden_states = hidden_states.to(dtype=torch.bfloat16)
                query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
                key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
                value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)

                query_states.append(query_state)
                key_states.append(key_state)
                value_states.append(value_state)

            # B,L,H,D with L sequence length, H number of heads, D head dim
            # concatenate on the number of embeddings/tokens
            query_states = torch.cat(query_states, dim=1)
            key_states = torch.cat(key_states, dim=1)
            value_states = torch.cat(value_states, dim=1)

            query_states = apply_rope(query_states, position_ids)
            key_states = apply_rope(key_states, position_ids)

            if use_cache and past_key_values is None:
                past_key_values = {}

            if use_cache:
                if fill_kv_cache:
                    past_key_values[layer_idx] = {
                        "key_states": key_states,
                        "value_states": value_states,
                    }
                else:
                    # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
                    # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
                    # the max len, then we (for instance) double the cache size. This implementation already exists
                    # in `transformers`. (molbap)
                    key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
                    value_states = torch.cat(
                        [past_key_values[layer_idx]["value_states"], value_states], dim=1
                    )

            attention_interface = self.get_attention_interface()
            att_output = attention_interface(
                attention_mask, batch_size, head_dim, query_states, key_states, value_states
            )
            att_output = att_output.to(dtype=torch.bfloat16)

            # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
            outputs_embeds = []
            start = 0
            for i, hidden_states in enumerate(inputs_embeds):
                layer = models[i].layers[layer_idx]

                if hidden_states is not None:
                    end = start + hidden_states.shape[1]

                    if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
                        att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
                    out_emb = layer.self_attn.o_proj(att_output[:, start:end])

                    # TODO: first dropout (by default 0.0)

                    # first residual
                    out_emb += hidden_states
                    after_first_residual = out_emb.clone()

                    out_emb = layer.post_attention_layernorm(out_emb)
                    out_emb = layer.mlp(out_emb)

                    # TODO: second dropout (by default 0.0)

                    # second residual
                    out_emb += after_first_residual

                    outputs_embeds.append(out_emb)

                    start = end
                else:
                    outputs_embeds.append(None)

            inputs_embeds = outputs_embeds

        # final norm
        outputs_embeds = []
        for i, hidden_states in enumerate(inputs_embeds):
            if hidden_states is not None:
                out_emb = models[i].norm(hidden_states)
                outputs_embeds.append(out_emb)
            else:
                outputs_embeds.append(None)

        return outputs_embeds, past_key_values
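
    # Illustrative sketch (not part of the original file), based on how the Pi0 policy appears to
    # call this forward; names like `prefix_embs` and `suffix_embs` are hypothetical placeholders.
    #
    #   # 1) Prefix pass (vision + language embeddings), filling the KV cache:
    #   _, cache = model.forward(
    #       attention_mask=prefix_2d_mask, position_ids=prefix_pos_ids,
    #       inputs_embeds=[prefix_embs, None], use_cache=True, fill_kv_cache=True,
    #   )
    #   # 2) Expert (action) passes reuse the cached prefix keys/values:
    #   (_, suffix_out), _ = model.forward(
    #       attention_mask=full_2d_mask, position_ids=suffix_pos_ids, past_key_values=cache,
    #       inputs_embeds=[None, suffix_embs], use_cache=True, fill_kv_cache=False,
    #   )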

    def get_attention_interface(self):
        if self.config.attention_implementation == "fa2":
            attention_interface = self.flash_attention_forward
        elif self.config.attention_implementation == "flex":
            attention_interface = flex_attention_forward
        else:
            attention_interface = self.eager_attention_forward
        return attention_interface

    def flash_attention_forward(
        self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
    ):
        raise NotImplementedError("FA2 is not implemented (yet)")

    def eager_attention_forward(
        self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
    ):
        num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
        num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
        num_key_value_groups = num_att_heads // num_key_value_heads

        # query_states: batch_size, sequence_length, num_att_head, head_dim
        # key_states: batch_size, sequence_length, num_key_value_head, head_dim
        # value_states: batch_size, sequence_length, num_key_value_head, head_dim
        sequence_length = key_states.shape[1]

        key_states = key_states[:, :, :, None, :].expand(
            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
        )
        key_states = key_states.reshape(
            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
        )

        value_states = value_states[:, :, :, None, :].expand(
            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
        )
        value_states = value_states.reshape(
            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
        )

        # Attention here is upcasted to float32 to match the original eager implementation.
        query_states = query_states.to(dtype=torch.float32)
        key_states = key_states.to(dtype=torch.float32)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        att_weights *= head_dim**-0.5
        big_neg = -2.3819763e38  # See gemma/modules.py
        masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)

        probs = nn.functional.softmax(masked_att_weights, dim=-1)
        probs = probs.to(dtype=value_states.dtype)

        # probs: batch_size, num_att_heads, sequence_length, sequence_length
        # value_states: batch_size, sequence_length, num_att_heads, head_dim
        att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
        att_output = att_output.permute(0, 2, 1, 3)

        # we use -1 because sequence length can change
        att_output = att_output.reshape(
            batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
        )

        return att_output
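

# Illustrative sketch (not part of the original file): the expand/reshape pattern in
# `eager_attention_forward` repeats each key/value head `num_key_value_groups` times so that
# grouped-query attention reduces to a plain matmul. The sizes below are arbitrary example
# values; the same result can be obtained with `torch.repeat_interleave`.
def _example_gqa_head_expansion():  # pragma: no cover - illustration only
    batch_size, seq_len, num_kv_heads, num_groups, head_dim = 2, 5, 1, 8, 4
    kv = torch.randn(batch_size, seq_len, num_kv_heads, head_dim)
    expanded = (
        kv[:, :, :, None, :]
        .expand(batch_size, seq_len, num_kv_heads, num_groups, head_dim)
        .reshape(batch_size, seq_len, num_kv_heads * num_groups, head_dim)
    )
    assert torch.equal(expanded, torch.repeat_interleave(kv, num_groups, dim=2))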