import re

import torch
import torch.nn as nn

class IdentityMap(nn.Module):
    """Pass-through projector: returns the input features unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}

class SimpleResBlock(nn.Module):
    """Residual block: LayerNorm, then a two-layer GELU MLP with a skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)

class DualMLPProjector(nn.Module):
    """MLP projector that fuses image features with the vision encoder's last
    hidden state before projecting into the language model's hidden size.

    The two inputs are concatenated along the feature dimension, so their
    combined width must equal 4 * mm_hidden_size to match the first linear layer.
    """

    def __init__(self, config, mlp_depth):
        super().__init__()
        self.encoder_mlp = nn.Sequential(
            nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
            *[nn.Sequential(nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size))
              for _ in range(mlp_depth - 1)]
        )

    def forward(self, image_features, encoder_last_hidden_state):
        # Concatenate along the feature dimension, then project to hidden_size.
        combined = torch.cat((image_features, encoder_last_hidden_state), dim=-1)
        return self.encoder_mlp(combined)

def build_vision_projector(config, delay_load=False, **kwargs):
    """Instantiate the projector named by config.mm_projector_type."""
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    # e.g. 'mlp2x_gelu' -> a depth-2 GELU MLP projector.
    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        return DualMLPProjector(config, mlp_depth)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
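

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of building and calling the projector. The config fields
# mirror those referenced above; the concrete sizes, the 3x/1x split of the
# 4 * mm_hidden_size input width, and the token count are assumptions made
# here purely for demonstration.
if __name__ == '__main__':
    from types import SimpleNamespace

    config = SimpleNamespace(
        mm_projector_type='mlp2x_gelu',  # matched by the mlpNx_gelu regex above
        mm_hidden_size=1024,
        hidden_size=4096,
    )
    projector = build_vision_projector(config)

    # Dummy inputs: batch of 2, 576 visual tokens. The two feature widths
    # (3 * 1024 and 1024) concatenate to 4 * mm_hidden_size = 4096, as the
    # first linear layer of DualMLPProjector expects.
    image_features = torch.randn(2, 576, 3 * config.mm_hidden_size)
    encoder_last_hidden_state = torch.randn(2, 576, config.mm_hidden_size)
    out = projector(image_features, encoder_last_hidden_state)
    print(out.shape)  # torch.Size([2, 576, 4096])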