# Author: mojtaba-nafez
# Duplicated from mojtaba-nafez/persian-poem-recommender-based-on-text (commit 1bc9b9d, 6.88 kB)
import torch
from torch import nn
import timm
import config as CFG
class TextEncoder(nn.Module):
    """
    Text/Poem encoder used in PoemTextModel and CLIPModel
    ...
    Attributes:
    -----------
    model : a torch.nn.Module model
        The text/poem encoder model (created via CFG.encoders[encoder_model])
    target_token_idx : int
        index of the token whose last hidden state is returned as the embedding
        (0, i.e. the CLS token)
    Methods:
    --------
    forward(input_ids, attention_mask)
        returns model embeddings of a batch of texts/poems (of the CLS token)
    __init__(encoder_model, encoder_pretrained_name, pretrained, trainable)
        creates the encoder model using huggingface transformers,
        also freezes the model if it's not trainable.
    """

    def __init__(self, encoder_model, encoder_pretrained_name, pretrained, trainable):
        """
        Creates the poem or text encoder model and loads weights from a pretrained model if needed.
        Also freezes the model if it's not trainable.

        Parameters:
        -----------
        encoder_model: str
            text/poem encoder model name; used as the key into CFG.encoders (model class)
            and CFG.configs (config class).
        encoder_pretrained_name: str
            pretrained text/poem encoder passed to `from_pretrained` to get weights from.
            (not used when pretrained=False)
        pretrained: bool
            if pretrained=True, load the pretrained model's weights. else create a fresh untrained
            model from the default config of `encoder_model`.
        trainable: bool
            if trainable=False, the model's weights will be frozen (requires_grad=False).
        """
        super().__init__()
        if pretrained:
            self.model = CFG.encoders[encoder_model].from_pretrained(encoder_pretrained_name)
        else:
            # Fresh model: instantiate the configured model class with a default config.
            self.model = CFG.encoders[encoder_model](config=CFG.configs[encoder_model]())
        for p in self.model.parameters():
            p.requires_grad = trainable
        # Using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """
        Forwards the input through the encoder and returns the target token's embedding.

        Parameters:
        -----------
        input_ids: input ids (output of tokenizer)
        attention_mask: input mask (for example for padding, pad tokens will be masked)

        Returns:
        --------
        the embedding of the CLS (or target) token taken from the encoder's last hidden state.
        Assuming last_hidden_state is (batch_size, seq_len, hidden_size) (the usual transformers
        layout), the result is (batch_size, hidden_size).
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]
class ProjectionHead(nn.Module):
    """
    Maps one encoder's output embeddings into the shared embedding space.

    Pipeline: projection (Linear) -> GELU -> fc (Linear) -> Dropout, then the output of
    `projection` is added back as a residual, and LayerNorm is applied to the sum.

    Attributes:
    -----------
    projection : torch.nn.Linear
        main dense projection (encoder embedding_dim -> projection_dim)
    gelu : torch.nn.GELU
        activation applied to the projection output
    fc : torch.nn.Linear
        dense layer after the activation (projection_dim -> projection_dim)
    dropout : torch.nn.Dropout
        dropout applied to the fc output
    layer_norm : torch.nn.LayerNorm
        normalization applied after the residual sum
    """

    def __init__(self, embedding_dim, projection_dim=CFG.projection_dim, dropout=CFG.dropout):
        """
        Builds the projection head placed after an encoder.

        Parameters:
        -----------
        embedding_dim: int
            size of the encoder's output embeddings.
        projection_dim: int, optional
            size of the shared embedding space (defaults to CFG.projection_dim).
        dropout: float, optional
            probability of zeroing elements of the fc output (defaults to CFG.dropout).
        """
        super().__init__()
        # Sub-modules are registered in this exact order and under these exact names:
        # the names are the state_dict keys, and the order fixes the parameter-init RNG sequence.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """
        Projects encoder embeddings into the shared space.

        Parameters:
        -----------
        x: tensor whose last dimension is embedding_dim
            (typically (batch_size, embedding_dim)), the output of this head's encoder.

        Returns:
        --------
        tensor with the same leading dimensions and last dimension projection_dim.
        """
        base = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(base)))
        # Residual connection around the GELU/fc/dropout branch, then normalize.
        return self.layer_norm(refined + base)
class ImageEncoder(nn.Module):
    """
    Image encoder used in CLIPModel
    ...
    Attributes:
    -----------
    model : a torch.nn.Module model from timm (pytorch-image-models)
        The image encoder model (created with num_classes=0 and global_pool="avg")
    Methods:
    --------
    forward(x)
        returns model embeddings of x (batch of images)
    __init__(pretrained, trainable, model_name)
        creates the encoder model using timm and loads fine-tuned model's state dict if needed.
        also freezes the model if it's not trainable.
    """

    def __init__(
        self, pretrained, trainable, model_name=CFG.image_encoder_model
    ):
        """
        Creates the encoder model using timm and loads a fine-tuned model's state dict if needed.
        Also freezes the model if it's not trainable.

        Parameters:
        -----------
        pretrained: bool
            if pretrained=True, use the weights saved in CFG.image_encoder_weights_load_path when
            that path is set, otherwise timm's pretrained weights for `model_name`.
            else create a fresh untrained model.
        trainable: bool
            if trainable=False, the model's weights will be frozen (requires_grad=False).
        model_name: str
            image encoder model name used as input to timm.create_model.
        """
        super().__init__()
        load_path = CFG.image_encoder_weights_load_path
        load_local_weights = bool(pretrained and load_path)
        # The strict load_state_dict below overwrites every entry of the model's state dict,
        # so when local weights are used, fetching timm's pretrained weights first would be
        # redundant (extra download / disk read, and it fails offline).
        use_timm_weights = bool(pretrained) and not load_local_weights
        self.model = timm.create_model(
            model_name, use_timm_weights, num_classes=0, global_pool="avg"
        )
        if load_local_weights:
            # NOTE: torch.load unpickles the file; only point this path at trusted checkpoints.
            state_dict = torch.load(load_path, map_location=CFG.device)
            self.model.load_state_dict(state_dict)
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        """
        Forwards the input through the encoder and returns its embeddings.

        Parameters:
        -----------
        x: input (batch of transformed images)

        Returns:
        --------
        embeddings of the model for the input (presumably (batch_size, image_embedding),
        since the timm classifier is removed with num_classes=0 and features are avg-pooled)
        """
        return self.model(x)