jiuhai's picture
llama
c4a668c
import torch
import torch.nn as nn
import re
class IdentityMap(nn.Module):
    """Projector that hands back its first input untouched.

    Any extra positional or keyword arguments passed to ``forward`` are
    accepted and discarded. The module owns no parameters.
    """

    # No custom __init__: nn.Module.__init__ is inherited and does all the
    # setup this parameter-free module needs.

    def forward(self, x, *args, **kwargs):
        # Pass-through; everything after ``x`` is intentionally ignored.
        return x

    @property
    def config(self):
        """Configuration dict naming this projector type."""
        return dict(mm_projector_type='identity')
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual Linear -> GELU -> Linear branch.

    Args:
        channels: size of the last dimension; every layer maps
            ``channels -> channels``.

    NOTE(review): the residual connection adds the MLP output to the
    *normalized* input (``pre_norm(x)``), not to the raw ``x``. This is kept
    as-is because changing it would change what trained weights compute.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        # Layers are created in the same order as before and stored in a
        # Sequential so state_dict keys remain proj.0.* / proj.2.*.
        first_linear = nn.Linear(channels, channels)
        activation = nn.GELU()
        second_linear = nn.Linear(channels, channels)
        self.proj = nn.Sequential(first_linear, activation, second_linear)

    def forward(self, x):
        normed = self.pre_norm(x)
        branch = self.proj(normed)
        return normed + branch
class DualMLPProjector(nn.Module):
    """Concatenate two feature tensors on the last dim and project them with a GELU MLP.

    Despite the name there is a single MLP, ``encoder_mlp``. Both inputs are
    joined with ``torch.cat(..., dim=-1)`` and fed through it together.

    Args:
        config: object exposing ``mm_hidden_size`` and ``hidden_size``.
        mlp_depth: number of Linear layers. Every Linear after the first is
            preceded by a GELU. Values <= 1 yield a single Linear layer.

    The first Linear expects ``4 * config.mm_hidden_size`` input features,
    so the last dims of ``image_features`` and ``encoder_last_hidden_state``
    must sum to that. NOTE(review): the reason for the 4x factor is not
    visible here; confirm against whatever produces these inputs.
    """

    def __init__(self, config, mlp_depth):
        super().__init__()
        width = config.hidden_size
        # The input Linear is built first, so parameter init order
        # (and RNG consumption) is unchanged.
        layers = [nn.Linear(config.mm_hidden_size * 4, width)]
        for _ in range(mlp_depth - 1):
            # Each later stage is a nested Sequential(GELU, Linear).
            # Its state_dict keys are encoder_mlp.<i>.1.*
            layers.append(nn.Sequential(nn.GELU(), nn.Linear(width, width)))
        self.encoder_mlp = nn.Sequential(*layers)

    def forward(self, image_features, encoder_last_hidden_state):
        fused = torch.cat([image_features, encoder_last_hidden_state], dim=-1)
        return self.encoder_mlp(fused)
def build_vision_projector(config, delay_load=False, **kwargs):
    """Build the projector module selected by ``config.mm_projector_type``.

    Args:
        config: object exposing ``mm_hidden_size`` and ``hidden_size``, and
            optionally ``mm_projector_type``. When that attribute is absent,
            the type defaults to ``'linear'``.
        delay_load: accepted for interface compatibility; not used here.
        **kwargs: accepted for interface compatibility; not used here.

    Returns:
        - ``'linear'``: ``nn.Linear(mm_hidden_size, hidden_size)``.
        - ``'mlp<N>x_gelu'`` with N >= 1: ``DualMLPProjector(config, N)``.
        - ``'identity'``: ``IdentityMap()``.

    NOTE(review): the returned modules do not share one call signature.
    ``nn.Linear`` takes a single tensor, while ``DualMLPProjector.forward``
    takes two. Callers must know which type they built.

    Raises:
        ValueError: if the type is unknown, or it is ``mlp<N>x_gelu``
            with N < 1.
    """
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        # Without this check, 'mlp0x_gelu' would silently build the same
        # single-Linear projector as 'mlp1x_gelu' (range(-1) is empty).
        if mlp_depth < 1:
            raise ValueError(
                f'Invalid MLP depth {mlp_depth} in projector type: {projector_type}'
            )
        return DualMLPProjector(config, mlp_depth)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')