Spaces:
Runtime error
Runtime error
File size: 6,878 Bytes
1bc9b9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
import torch
from torch import nn
import timm
import config as CFG
class TextEncoder(nn.Module):
    """
    Text/Poem encoder used in PoemTextModel and CLIPModel
    ...
    Attributes:
    -----------
    model : a torch.nn.Module model
        The text/poem encoder model (class selected from CFG.encoders)
    target_token_idx : int
        Sequence position whose last hidden state is used as the
        sentence embedding (0, i.e. the CLS token).
    Methods:
    --------
    forward(input_ids, attention_mask)
        returns model embeddings of a batch of texts/poems (of the CLS token)
    __init__(encoder_model, encoder_pretrained_name, pretrained, trainable)
        creates the encoder model using huggingface transformers,
        also freezes the model if it's not trainable.
    """

    def __init__(self, encoder_model, encoder_pretrained_name, pretrained, trainable):
        """
        creates the poem or text encoder model using transformers and loads weights from pretrained model if needed.
        Also freezes the model if it's not trainable.
        Parameters:
        -----------
        pretrained: bool
            if pretrained=True, get pretrained model's weights. else create a fresh untrained model.
        trainable: bool
            if trainable=False, the model's weights will be frozen.
        encoder_model: str
            text/poem encoder model name; key into CFG.encoders (model class) and
            CFG.configs (config class, used only when pretrained=False).
        encoder_pretrained_name: str
            name or path passed to the encoder class's from_pretrained to get
            weights from. (not used when pretrained=False)
        """
        super().__init__()
        if pretrained:
            self.model = CFG.encoders[encoder_model].from_pretrained(encoder_pretrained_name)
        else:
            # fresh, randomly initialized model built from a default config
            self.model = CFG.encoders[encoder_model](config=CFG.configs[encoder_model]())
        # requires_grad is set explicitly either way: False freezes, True unfreezes
        for p in self.model.parameters():
            p.requires_grad = trainable
        # Using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """
        forwards and calculates embeddings of the input using attention mask.
        Parameters:
        -----------
        input_ids: input ids (output of tokenizer)
        attention_mask: input masks (for example for padding, pad tokens will be masked)
        Returns:
        --------
        the embedding of the CLS (or target) token of the encoder's last hidden state,
        of shape (batch_size, hidden_size) -- assuming last_hidden_state is
        (batch_size, seq_len, hidden_size), as the indexing below implies.
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]
class ProjectionHead(nn.Module):
    """
    Maps one encoder's output embeddings into the shared embedding space.
    ...
    Attributes:
    -----------
    projection : torch.nn.Linear
        Main dense projection from the encoder's embedding_dim to projection_dim.
    gelu : torch.nn.GELU
        Non-linearity applied to the main projection.
    fc : torch.nn.Linear
        Dense layer (projection_dim -> projection_dim) on the activated projection.
    dropout : torch.nn.Dropout
        Dropout applied to the output of fc.
    layer_norm : torch.nn.LayerNorm
        Normalization applied to the residual sum (projection + dropped-out fc branch).
    Methods:
    --------
    forward(x)
        returns the projected embeddings of x (encoder output embeddings)
    __init__(embedding_dim, projection_dim, dropout)
        builds the layers of the head
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        """
        Builds the projection head placed after an encoder.
        Parameters:
        -----------
        embedding_dim: int
            size of the encoder's output embeddings.
        projection_dim: int, optional
            size of the shared embedding space to project into.
        dropout: float
            probability of zeroing each element of the fc layer's output.
        """
        super().__init__()
        in_features, out_features = embedding_dim, projection_dim
        # Submodules are registered in this exact order on purpose: it fixes the
        # state_dict key order, the parameters() order and the RNG draws at init.
        self.projection = nn.Linear(in_features, out_features)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(out_features, out_features)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(out_features)

    def forward(self, x):
        """
        Projects encoder embeddings into the shared space.
        Parameters:
        -----------
        x: tensor of shape (batch_size, embedding_dim)
            output embeddings of this head's encoder
        Returns:
        --------
        tensor of shape (batch_size, projection_dim) in the shared embedding space
        """
        base = self.projection(x)
        # side branch: activation -> dense -> dropout
        branch = self.dropout(self.fc(self.gelu(base)))
        # residual connection back onto the main projection, then normalize
        return self.layer_norm(branch + base)
class ImageEncoder(nn.Module):
    """
    Image encoder used in CLIPModel
    ...
    Attributes:
    -----------
    model : a torch.nn.Module model from timm (pytorch-image-models)
        The image encoder model (created with num_classes=0 and global_pool="avg")
    Methods:
    --------
    forward(x)
        returns model embeddings of x (batch of images)
    __init__(pretrained, trainable, model_name)
        creates the encoder model using timm and loads fine-tuned model's state dict if needed.
        also freezes the model if it's not trainable.
    """

    def __init__(
        self, pretrained, trainable, model_name=CFG.image_encoder_model
    ):
        """
        creates the encoder model using timm and loads fine-tuned model's state dict if needed.
        Also freezes the model if it's not trainable.
        Parameters:
        -----------
        pretrained: bool
            if pretrained=True and CFG.image_encoder_weights_load_path is set, load the
            state dict saved at that path; if pretrained=True and no path is set, use
            timm's pretrained weights. else create a fresh untrained model.
        trainable: bool
            if trainable=False, the model's weights will be frozen.
        model_name: str
            image encoder model name used as input to timm.create_model.
        """
        super().__init__()
        weights_path = CFG.image_encoder_weights_load_path
        load_local = bool(pretrained and weights_path)
        # If a local checkpoint will be loaded (strict load_state_dict overwrites
        # every parameter/buffer), don't make timm download hub weights first:
        # they would be discarded immediately and the download needs network access.
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained and not load_local,
            num_classes=0,
            global_pool="avg",
        )
        if load_local:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(weights_path, map_location=CFG.device)
            self.model.load_state_dict(state_dict)
        # requires_grad is set explicitly either way: False freezes, True unfreezes
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        """
        forwards and calculates embeddings of the input.
        Parameters:
        -----------
        x: input (batch of transformed images)
        Returns:
        --------
        embeddings of the model for the input (of shape (batch_size, image_embedding))
        """
        return self.model(x)
|