# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torchvision.transforms import Resize
from transformers import AutoConfig, AutoModel, Siglip2VisionConfig, Siglip2VisionModel

from . import models
from .utils import ScalingLayer


class TextAlignedTokenizer(nn.Module):
    """Image/video tokenizer made of a vision encoder, a bottleneck and a decoder.

    Pipeline, as wired in this class:

    * ``encoder``: the ``vision_model`` of an ``AutoModel`` built from the
      ``teacher`` config; one of its hidden states is used as the feature map.
    * ``encode_task_layer``: ``Linear -> Tanh`` applied to those features.
    * ``bottleneck``: built through ``models.make`` from the ``bottleneck`` spec.
    * ``decoder``: a ``Siglip2VisionModel`` with ``patch_size=1`` that consumes
      bottleneck outputs, followed by ``decode_task_layer``.
    """

    # Values of ``input_type`` that ``forward`` knows how to serve.
    _INPUT_TYPES = ('quant', 'rec', 'indices')

    def __init__(
        self,
        bottleneck,
        bottleneck_token_num=256,
        input_size=384,
        teacher='google/siglip2-so400m-patch14-384',
        input_type='quant',  # choose from ['quant', 'rec', 'indices']
        pool_scale=1,  # choose from [1, 2, 3]
        decoder_depth=3,
        select_layer_id=-2,
        *args,
        **kwargs
    ):
        """Build encoder, decoder, task heads and bottleneck.

        Args:
            bottleneck: Spec passed to ``models.make``; must provide
                ``bottleneck['args']['bottleneck_dim']``.
            bottleneck_token_num: Forwarded to the bottleneck as ``token_nums``.
            input_size: Side length inputs are resized to before encoding.
            teacher: Identifier given to ``AutoConfig.from_pretrained`` to
                obtain the encoder configuration.
            input_type: Which tensor ``forward`` stores under ``'encoded'``;
                one of ``'quant'``, ``'rec'``, ``'indices'``.
            pool_scale: Default average-pooling factor applied to encoder
                features in ``encode`` (``1`` disables pooling).
            decoder_depth: Number of hidden layers of the decoder.
            select_layer_id: Index into the encoder's ``hidden_states``.
            *args, **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.input_size = input_size
        self.bottleneck_token_num = bottleneck_token_num
        self.teacher = teacher
        self.input_type = input_type
        self.pool_scale = pool_scale
        self.decoder_depth = decoder_depth
        self.select_layer_id = select_layer_id

        self.bottleneck_dim = bottleneck['args']['bottleneck_dim']

        # The encoder is instantiated from the config only; ``from_config``
        # does not load pretrained weights.
        self.encoder_config = AutoConfig.from_pretrained(teacher)
        self.encoder = AutoModel.from_config(self.encoder_config).vision_model
        self.encoder_hidden_dim = self.encoder.config.hidden_size

        # Decoder works on one "pixel" per token (patch_size=1) whose channel
        # count is the bottleneck dimension.
        self.decoder_config = Siglip2VisionConfig()
        self.decoder_config.update({
            'patch_size': 1,
            'num_hidden_layers': self.decoder_depth,
            'num_channels': self.bottleneck_dim,
            'hidden_size': self.encoder_hidden_dim,
        })
        self.decoder = Siglip2VisionModel(self.decoder_config)

        self.encode_task_layer = nn.Sequential(
            nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
            nn.Tanh(),
        )
        self.decode_task_layer = nn.Sequential(
            nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
            nn.Tanh(),
            nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
        )

        bottleneck_args = {
            'token_nums': self.bottleneck_token_num,
            'input_dim': self.encoder_hidden_dim,
            'output_dim': self.bottleneck_dim,
        }
        self.bottleneck = models.make(bottleneck, args=bottleneck_args)

        # NOTE(review): presumably normalizes inputs with the given mean/std;
        # confirm against ScalingLayer in .utils.
        self.scale_layer = ScalingLayer(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.image_resize = Resize((self.input_size, self.input_size))

    def set_vq_eval_deterministic(self, deterministic=True):
        """Forward ``deterministic`` to the bottleneck regularizer's eval switch."""
        self.bottleneck.regularizer.set_eval_deterministic(deterministic)

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    @classmethod
    def from_checkpoint(cls, ckpt, load_teacher=True, **kwargs):
        """Build a model from a checkpoint file and load its weights.

        The checkpoint is expected to hold ``ckpt["model"]["args"]``
        (constructor kwargs) and ``ckpt["model"]["sd"]`` (state dict).

        Args:
            ckpt: Path (or file-like object) accepted by ``torch.load``.
            load_teacher: When False, state-dict entries whose key starts with
                ``'teacher'`` are dropped before loading.
            **kwargs: Constructor overrides. They take precedence over the
                arguments stored in the checkpoint.

        Returns:
            The constructed model with weights loaded (``strict=True``).
        """
        # NOTE: torch.load unpickles the file; only use checkpoints from a
        # trusted source.
        ckpt = torch.load(ckpt, map_location='cpu')
        # Merge instead of passing both mappings with ``**`` so that a key
        # present in both does not raise "multiple values for keyword".
        init_kwargs = {**ckpt["model"]["args"], **kwargs}
        model = cls(**init_kwargs)
        sd = ckpt["model"]["sd"]
        if not load_teacher:
            sd = {k: v for k, v in sd.items() if not k.startswith('teacher')}
        model.load_state_dict(sd, strict=True)
        return model

    def encode(self, x, **kwargs):
        """Encode images (or video frames) into bottleneck outputs.

        Args:
            x: 4-D image batch, or a 5-D ``(b, c, t, h, w)`` tensor whose
                frames are folded into the batch dimension.
            **kwargs: ``pool_scale`` may be given to override
                ``self.pool_scale`` for this call.

        Returns:
            dict with ``'encoded'`` (the bottleneck's ``'output'`` entry),
            ``'pool_scale'`` (the factor actually used), ``'vq_feats'`` (the
            features fed to the bottleneck) and every remaining entry of the
            bottleneck's output dict.
        """
        if x.ndim == 5:
            x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.scale_layer(x)
        if tuple(x.shape[-2:]) != (self.input_size, self.input_size):
            x = self.image_resize(x)

        hidden_states = self.encoder(x, output_hidden_states=True).hidden_states
        vq_feats = hidden_states[self.select_layer_id]

        pool_scale = kwargs.get("pool_scale", self.pool_scale)
        if pool_scale != 1:
            vq_feats = self.avg_pool(vq_feats, pool_scale)

        # ``.to(x)`` aligns the features with the input's dtype/device.
        vq_feats = self.encode_task_layer(vq_feats.to(x))

        bottleneck_out = self.bottleneck(vq_feats)
        z = bottleneck_out.pop('output')

        return {
            'encoded': z,
            'pool_scale': pool_scale,
            'vq_feats': vq_feats,
            **bottleneck_out,
        }

    def avg_pool(self, z, pool_scale=1):
        """Average-pool a token grid by ``pool_scale`` along both spatial axes.

        Args:
            z: Either ``(b, n, c)`` tokens, reshaped to a ``p x p`` grid with
                ``p = int(n ** 0.5)``, or an already spatial ``(b, c, p1, p2)``
                tensor.
            pool_scale: Kernel size and stride of the pooling window.

        Returns:
            Pooled tokens flattened back to ``(b, n', c)``.
        """
        if z.ndim == 3:
            p = int(z.shape[1] ** 0.5)
            z = rearrange(z, 'b (p1 p2) c -> b c p1 p2', p1=p, p2=p)
        z = F.avg_pool2d(
            z,
            kernel_size=(pool_scale, pool_scale),
            stride=(pool_scale, pool_scale),
        ).contiguous()
        return rearrange(z, 'b c p1 p2 -> b (p1 p2) c')

    def decode(self, z):
        """Run the decoder and the decode task head on bottleneck features.

        Args:
            z: ``(b, c, p1, p2)`` feature map (flattened to tokens here) or a
                tensor that is already laid out as ``(b, n, c)``.

        Returns:
            Output of ``decode_task_layer`` applied to the decoder's
            ``last_hidden_state``.
        """
        if z.ndim == 4:
            z = rearrange(z, 'b c p1 p2 -> b (p1 p2) c')
        # Every token is valid, so the mask is all ones.
        attention_mask = torch.ones(z.shape[:2], dtype=torch.int, device=z.device)
        p = int(z.shape[1] ** 0.5)
        # Created on z's device (like the mask) so all decoder inputs agree.
        spatial_shape = torch.tensor([[p, p]] * z.shape[0], device=z.device)
        z = self.decoder(
            z, attention_mask, spatial_shape, output_hidden_states=True
        ).last_hidden_state
        return self.decode_task_layer(z)

    def decode_from_bottleneck(self, bottleneck_rep):
        """Decode a bottleneck representation back to feature space."""
        z = self.bottleneck.decode(bottleneck_rep)  # (b, n, c)
        p = int(z.shape[1] ** 0.5)
        z = rearrange(z, 'b (p1 p2) c -> b c p1 p2', p1=p, p2=p)
        return self.decode(z)

    def forward(self, data, **kwargs):
        """Encode ``data`` and expose the tensor selected by ``input_type``.

        Args:
            data: video in shape (b, c, t, h, w); anything ``encode`` accepts.
            **kwargs: Forwarded to ``encode``.

        Returns:
            The dict produced by ``encode`` with ``'encoded'`` replaced by:
            ``'quant'`` -> ``regularized_z``, ``'indices'`` ->
            ``bottleneck_rep``, ``'rec'`` -> decoder reconstruction.

        Raises:
            ValueError: If ``self.input_type`` is not a supported value.
        """
        if self.input_type not in self._INPUT_TYPES:
            raise ValueError(
                f"Unsupported input_type {self.input_type!r}; "
                f"expected one of {list(self._INPUT_TYPES)}"
            )

        encode_output = self.encode(data, **kwargs)

        if self.input_type == 'quant':
            z = encode_output["regularized_z"]  # [b, n, c]
        elif self.input_type == 'indices':
            z = encode_output["bottleneck_rep"]  # [b, n]
        else:
            # 'rec': the decoder pass is only needed for this mode, so it is
            # skipped for the other two.
            vq_feats = encode_output['encoded']
            p = int(vq_feats.shape[1] ** 0.5)
            vq_feats = rearrange(vq_feats, 'b (h w) c -> b c h w', h=p, w=p)
            z = self.decode(vq_feats)  # [b, n, c]

        encode_output['encoded'] = z
        return encode_output