# enamine_embedding_mapper / modeling_mapper.py
# Uploaded by entropy — "Upload model" (commit 1b36437, verified)
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_mapper import MapperConfig
class FeedForward(nn.Module):
    """SiLU-gated (SwiGLU-style) projection from ``d_in`` to ``d_out`` features.

    ``fc1`` emits ``2 * d_out`` features, which are split in half along the last
    dimension into a gate and a value. The gated product
    ``silu(gate) * value`` is then projected by ``fc2``.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the last dimension of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Attribute names (and creation order) define the state_dict layout and
        # the RNG consumption during init; keep them stable for checkpoints.
        self.fc1 = nn.Linear(d_in, 2 * d_out)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        """Return ``fc2(silu(gate) * value)`` where ``gate, value`` split ``fc1(x)``."""
        gate, value = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(F.silu(gate) * value)
class FeedForwardLayer(nn.Module):
    """Residual gated feed-forward layer with optional LayerNorm.

    Computes ``LayerNorm(ff(drop(x)) + skip(drop(x)))``. The LayerNorm is
    omitted when it is disabled, and the same dropped-out input feeds both the
    feed-forward branch and the skip branch.

    Args:
        d_in: input feature size.
        d_out: output feature size.
        dropout: dropout probability applied to the layer input.
        layer_norm_eps: when truthy, a ``LayerNorm(d_out, eps=layer_norm_eps)``
            is applied to the output. When falsy (``None`` or ``0``), no
            normalisation is applied.
    """

    def __init__(self, d_in, d_out, dropout=0.1, layer_norm_eps=None):
        super().__init__()
        # Creation order (ff, skip, dropout, LayerNorm) is preserved so the
        # state_dict layout and seeded initialisation are unchanged.
        self.ff = FeedForward(d_in, d_out)
        # The residual path needs a projection only when the widths differ.
        if d_in != d_out:
            self.skip = nn.Linear(d_in, d_out)
        else:
            self.skip = nn.Identity()
        self.dropout = nn.Dropout(dropout)
        # Truthiness test: layer_norm_eps=0.0 also disables the norm.
        if layer_norm_eps:
            self.LayerNorm = nn.LayerNorm(d_out, eps=layer_norm_eps)
        else:
            self.LayerNorm = None

    def forward(self, x):
        """Apply dropout, the residual gated feed-forward, then the optional norm."""
        x = self.dropout(x)
        out = self.ff(x) + self.skip(x)
        return out if self.LayerNorm is None else self.LayerNorm(out)
class Mapper(nn.Module):
    """Stack of ``FeedForwardLayer`` blocks that maps an input to ``n_out`` vectors.

    Layout:
      * ``layers[0]``: ``d_in -> d_hidden`` with no dropout;
      * ``layers[1..n_layers]``: ``d_hidden -> d_hidden`` with ``dropout``;
      * ``output_layer``: ``d_hidden -> d_out * n_out`` with no dropout and no
        LayerNorm.

    All layers except ``output_layer`` use ``layer_norm_eps``. The output
    features are split into ``n_out`` chunks of width ``d_out`` along the last
    dimension and stacked on dim 1. For a 2-D input ``(batch, d_in)``, the
    result is ``(batch, n_out, d_out)``.
    """

    def __init__(self, d_in, d_hidden, d_out, n_out, n_layers, dropout=0.1, layer_norm_eps=None):
        super().__init__()
        self.n_out = n_out
        # Build the input layer before the hidden layers so parameter creation
        # order (and therefore seeded initialisation) is unchanged.
        input_layer = FeedForwardLayer(d_in, d_hidden, 0.0, layer_norm_eps)
        hidden_layers = [
            FeedForwardLayer(d_hidden, d_hidden, dropout, layer_norm_eps)
            for _ in range(n_layers)
        ]
        self.layers = nn.Sequential(input_layer, *hidden_layers)
        self.output_layer = FeedForwardLayer(d_hidden, d_out * n_out, 0.0, None)

    def forward(self, x):
        """Run the layer stack and reshape the output into ``n_out`` stacked chunks.

        NOTE(review): for an unbatched 1-D input, stacking on dim 1 yields
        ``(d_out, n_out)`` rather than ``(n_out, d_out)``. Confirm that callers
        always pass a batch dimension.
        """
        features = self.output_layer(self.layers(x))
        pieces = features.chunk(self.n_out, dim=-1)
        return torch.stack(pieces, dim=1)
@dataclass
class MapperModelOutput(ModelOutput):
    """Output container returned by ``MapperModel.forward`` when ``return_dict`` is truthy.

    Attributes:
        mapper_out: tensor produced by the underlying ``Mapper``. It defaults to
            ``None``, so the annotation is ``Optional`` to match the default
            (the usual convention for ``ModelOutput`` fields).
    """

    mapper_out: Optional[torch.FloatTensor] = None
class MapperModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a single ``Mapper`` built from a ``MapperConfig``.

    Reads ``d_in``, ``d_hidden``, ``d_out``, ``n_out``, ``n_layers``, ``dropout``
    and ``layer_norm_eps`` from the config.

    NOTE(review): ``self.post_init()`` is not called at the end of ``__init__``,
    unlike most ``transformers`` models. Confirm this is intentional before
    adding it, since it may re-run weight initialisation.
    """

    config_class = MapperConfig

    def __init__(self, config):
        super().__init__(config)
        self.mapper = Mapper(
            d_in=config.d_in,
            d_hidden=config.d_hidden,
            d_out=config.d_out,
            n_out=config.n_out,
            n_layers=config.n_layers,
            dropout=config.dropout,
            layer_norm_eps=config.layer_norm_eps,
        )

    def forward(self, embedding, return_dict=True):
        """Map ``embedding`` through the ``Mapper``.

        Returns:
            ``MapperModelOutput(mapper_out=...)`` when ``return_dict`` is truthy,
            otherwise the 1-tuple ``(mapper_out,)``.
        """
        mapper_out = self.mapper(embedding)
        if return_dict:
            return MapperModelOutput(mapper_out=mapper_out)
        return (mapper_out,)