# Duplicated from mojtaba-nafez/persian-poem-recommender-based-on-text
# (author: mojtaba-nafez, commit 1bc9b9d)
import torch
from torch import nn
import torch.nn.functional as F
#FIX
import config as CFG
from modules import TextEncoder, ProjectionHead, ImageEncoder
class PoemTextModel(nn.Module):
    """
    Model predicting poem and text embeddings, and their similarities.

    Attributes:
    -----------
    poem_encoder : TextEncoder
        encoder used for extracting poem features
    text_encoder : TextEncoder
        encoder used for extracting text features
    poem_projection: ProjectionHead
        projection head used for poem features (projects poem encoder output to the shared embedding space)
    text_projection: ProjectionHead
        projection head used for text features (projects text encoder output to the shared embedding space)
    temperature: float
        used to scale the dot similarities in the loss
    Methods:
    --------
    forward(batch):
        returns poem and text embeddings of batch
    similarity_scores(batch):
        computes cosine similarities between every text and every poem of a batch
    predict(batch):
        predicts the most similar poem idx for each text (using similarity_scores)
    calculate_loss(poem_embeddings, text_embeddings):
        computes contrastive (cross entropy) loss for both poems and texts.
    save_current():
        saves current model's encoders (if trainable) and projection heads.
    """
    def __init__(
        self,
        poem_encoder_pretrained,
        text_encoder_pretrained,
        temperature=CFG.temperature,
        poem_embedding=CFG.poem_embedding,
        text_embedding=CFG.text_embedding,
    ):
        """
        Initializes model's submodules.
        Parameters:
        -----------
        poem_encoder_pretrained: bool
            whether or not to load a pretrained poem encoder.
        text_encoder_pretrained: bool
            whether or not to load a pretrained text encoder.
        temperature: float, optional
            used to scale the dot similarities
        poem_embedding: int, optional
            dim of poem encoder's encoding output before projection
        text_embedding: int, optional
            dim of text encoder's encoding output before projection
        """
        super().__init__()
        # Encoder architectures, pretrained names and trainability all come from the config module.
        self.poem_encoder = TextEncoder(CFG.poem_encoder_model, CFG.poem_encoder_pretrained_name, pretrained=poem_encoder_pretrained, trainable=CFG.poem_encoder_trainable)
        self.text_encoder = TextEncoder(CFG.text_encoder_model, CFG.text_encoder_pretrained_name, pretrained=text_encoder_pretrained, trainable=CFG.text_encoder_trainable)
        self.poem_projection = ProjectionHead(embedding_dim=poem_embedding)
        if CFG.poem_projection_load_path:  # if provided, load projection weights from this path
            self.poem_projection.load_state_dict(torch.load(CFG.poem_projection_load_path, map_location=CFG.device))
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        if CFG.text_projection_load_path:  # if provided, load projection weights from this path
            self.text_projection.load_state_dict(torch.load(CFG.text_projection_load_path, map_location=CFG.device))
        self.temperature = temperature

    def forward(self, batch):
        """
        Returns poem and text embeddings of batch.
        Parameters:
        -----------
        batch: dict
            poem-text pairs with keys 'beyt' and 'text'; each value is the tokenizer output
            of the corresponding encoder and must contain 'input_ids' and 'attention_mask'.
        Returns:
        --------
        (poem_embeddings, text_embeddings): outputs of the two projection heads
            (per the original design, each of shape (batch_size, projection_dim)).
        """
        beyts, texts = batch["beyt"], batch["text"]
        # Getting Beyt and Text Features
        poem_features = self.poem_encoder(
            input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"]
        )
        text_features = self.text_encoder(
            input_ids=texts["input_ids"], attention_mask=texts["attention_mask"]
        )
        # Getting Beyt and Text Embeddings (projected to the same dimension)
        poem_embeddings = self.poem_projection(poem_features)
        text_embeddings = self.text_projection(text_features)
        return poem_embeddings, text_embeddings

    def similarity_scores(self, batch):
        """
        Computes cosine similarities between every text and every poem of a batch.
        Parameters:
        -----------
        batch: dict
            poem-text pairs with keys 'beyt' and 'text' (see forward).
        Returns:
        --------
        similarity matrix whose rows index texts and whose columns index poems
            (of shape (batch_size, batch_size)).
        """
        # Getting Beyt and Text Embeddings (with same dimension)
        poem_embeddings, text_embeddings = self.forward(batch)
        # L2-normalizing makes the dot product below a cosine similarity
        poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        dot_similarity = text_embeddings_n @ poem_embeddings_n.T
        return dot_similarity  # dim 0 indexes texts, dim 1 indexes poems

    def predict(self, batch):
        """
        Predicts the most similar poem (idx) for each text (using similarity_scores).
        Parameters:
        -----------
        batch: dict
            poem-text pairs with keys 'beyt' and 'text' (see forward).
        Returns:
        --------
        index of the poem predicted for each text (of shape (batch_size,))
        """
        dot_similarity = self.similarity_scores(batch)
        # argmax over dim 1 (the poem axis): for each text row, pick the most similar poem column
        return torch.argmax(dot_similarity, dim=1)

    def calculate_loss(self, poem_embeddings, text_embeddings):
        """
        Computes contrastive (cross entropy) loss for both poems and texts.
        Parameters:
        -----------
        poem_embeddings: tensor of shape (batch_size, projection_dim)
            output embeddings of poem projection head
        text_embeddings: tensor of shape (batch_size, projection_dim)
            output embeddings of text projection head
        Returns:
        --------
        scalar: mean of the symmetric (text-wise and poem-wise) loss
        """
        # dot similarity of the (unnormalized) embeddings scaled by temperature (logits);
        # rows index texts, columns index poems
        logits = (text_embeddings @ poem_embeddings.T) / self.temperature
        # Soft targets: the poem-poem and text-text similarity matrices are averaged,
        # scaled by the temperature and turned into row-wise probability distributions via softmax.
        # NOTE(review): logits are *divided* by temperature but the targets are *multiplied* by it;
        # the two agree only when temperature == 1.0 — confirm this is intended.
        poems_similarity = poem_embeddings @ poem_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (poems_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        # cross entropy in both directions: once over rows (texts) and once over columns (poems)
        texts_loss = cross_entropy(logits, targets, reduction='none')
        poems_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (poems_loss + texts_loss) / 2.0  # per-sample average of both losses, shape: (batch_size,)
        return loss.mean()

    def save_current(self):
        """
        Saves current model's encoders (only those configured as trainable) and both projection heads
        to the save paths given in the config module.
        """
        if CFG.text_encoder_trainable:
            self.text_encoder.model.save_pretrained(CFG.text_encoder_save_path)
        if CFG.poem_encoder_trainable:
            self.poem_encoder.model.save_pretrained(CFG.poem_encoder_save_path)
        torch.save(self.text_projection.state_dict(), CFG.text_projection_save_path)
        torch.save(self.poem_projection.state_dict(), CFG.poem_projection_save_path)
class CLIPModel(nn.Module):
    """
    Model predicting poem/text and image embeddings, and their similarities.

    Attributes:
    -----------
    encoder : TextEncoder
        encoder used for extracting poem/text features
    image_encoder : ImageEncoder
        encoder used for extracting image features
    text_projection: ProjectionHead
        projection head used for poem/text features (projects text encoder output to the shared embedding space)
    image_projection: ProjectionHead
        projection head used for image features (projects image encoder output to the shared embedding space)
    temperature: float
        used to scale the dot similarities in the loss
    Methods:
    --------
    forward(batch):
        returns image and poem/text embeddings of batch
    similarity_scores(batch):
        computes cosine similarities between every image and every poem/text of a batch
    predict(batch):
        predicts the most similar poem/text idx for each image (using similarity_scores)
    calculate_loss(image_embeddings, text_embeddings):
        computes contrastive (cross entropy) loss for both poems/texts and images.
    save_current():
        saves current model's encoders and text projection (if trainable) and the image projection.
    """
    def __init__(
        self,
        image_encoder_pretrained,
        text_encoder_pretrained,
        text_projection_trainable,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
        is_image_poem_pair=True
    ):
        """
        Initializes model's submodules.
        Parameters:
        -----------
        image_encoder_pretrained: bool
            whether or not to load a pretrained image encoder.
        text_encoder_pretrained: bool
            whether or not to load a pretrained text encoder.
        text_projection_trainable: bool
            whether or not to train the text projection
            (if False its parameters are frozen with requires_grad=False)
        temperature: float, optional
            used to scale the dot similarities
        image_embedding: int, optional
            dim of image encoder's encoding output before projection
        text_embedding: int, optional
            dim of text encoder's encoding output before projection
        is_image_poem_pair: bool, optional
            if True, the text inputs to this model are poems and the poem encoder/projection settings are used;
            otherwise they are texts and the text encoder/projection settings are used.
        """
        super().__init__()
        # Loading the encoders and their projections using configs
        self.image_encoder = ImageEncoder(pretrained=image_encoder_pretrained, trainable=CFG.image_encoder_trainable)
        if is_image_poem_pair:
            self.encoder = TextEncoder(CFG.poem_encoder_model, CFG.poem_encoder_pretrained_name, pretrained=text_encoder_pretrained, trainable=CFG.poem_encoder_trainable)
            text_projection_load_path = CFG.poem_projection_load_path
        else:
            self.encoder = TextEncoder(CFG.text_encoder_model, CFG.text_encoder_pretrained_name, pretrained=text_encoder_pretrained, trainable=CFG.text_encoder_trainable)
            text_projection_load_path = CFG.text_projection_load_path
        # NOTE(review): the poem branch also sizes its projection with `text_embedding`, which assumes
        # the poem encoder output dim equals CFG.text_embedding — confirm against the config.
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        if text_projection_load_path:  # if provided, load projection weights from this path
            self.text_projection.load_state_dict(torch.load(text_projection_load_path, map_location=CFG.device))
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        if CFG.image_projection_load_path:  # if provided, load projection weights from this path
            self.image_projection.load_state_dict(torch.load(CFG.image_projection_load_path, map_location=CFG.device))
        if not text_projection_trainable:
            # freeze the (typically pre-trained) text projection so only the other parts are trained
            for p in self.text_projection.parameters():
                p.requires_grad = False
        self.text_projection_trainable = text_projection_trainable
        self.is_image_poem_pair = is_image_poem_pair
        self.temperature = temperature

    def forward(self, batch):
        """
        Returns image and poem/text embeddings of batch.
        Parameters:
        -----------
        batch: dict
            image-text/poem pairs with keys 'image' and 'text'; 'text' is the encoder's tokenizer output
            and must contain 'input_ids' and 'attention_mask'.
        Returns:
        --------
        (image_embeddings, text_embeddings): outputs of the two projection heads
            (per the original design, each of shape (batch_size, projection_dim)).
        """
        images, texts = batch["image"], batch["text"]
        # Getting Image and Text Features
        image_features = self.image_encoder(images)
        text_features = self.encoder(
            input_ids=texts["input_ids"], attention_mask=texts["attention_mask"]
        )
        # Getting Image and Text Embeddings (projected to the same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)
        return image_embeddings, text_embeddings

    def similarity_scores(self, batch):
        """
        Computes cosine similarities between every image and every poem/text of a batch.
        Parameters:
        -----------
        batch: dict
            image-text/poem pairs with keys 'image' and 'text' (see forward).
        Returns:
        --------
        similarity matrix whose rows index images and whose columns index poems/texts
            (of shape (batch_size, batch_size)).
        """
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings, text_embeddings = self.forward(batch)
        # L2-normalizing makes the dot product below a cosine similarity
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        dot_similarity = image_embeddings_n @ text_embeddings_n.T
        return dot_similarity  # dim 0 indexes images, dim 1 indexes poems/texts

    def predict(self, batch):
        """
        Predicts the most similar poem/text (idx) for each image (using similarity_scores).
        Parameters:
        -----------
        batch: dict
            image-text/poem pairs with keys 'image' and 'text' (see forward).
        Returns:
        --------
        index of the poem/text predicted for each image (of shape (batch_size,))
        """
        dot_similarity = self.similarity_scores(batch)
        # argmax over dim 1 (the poem/text axis): for each image row, pick the most similar poem/text column
        return torch.argmax(dot_similarity, dim=1)

    def calculate_loss(self, image_embeddings, text_embeddings):
        """
        Computes contrastive (cross entropy) loss for both poems/texts and images.
        Parameters:
        -----------
        image_embeddings: tensor of shape (batch_size, projection_dim)
            output embeddings of image projection head
        text_embeddings: tensor of shape (batch_size, projection_dim)
            output embeddings of text projection head
        Returns:
        --------
        scalar: mean of the symmetric (text-wise and image-wise) loss
        """
        # dot similarity of the (unnormalized) embeddings scaled by temperature (logits);
        # rows index texts, columns index images
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        # Soft targets: the image-image and text-text similarity matrices are averaged,
        # scaled by the temperature and turned into row-wise probability distributions via softmax.
        # NOTE(review): logits are *divided* by temperature but the targets are *multiplied* by it;
        # the two agree only when temperature == 1.0 — confirm this is intended.
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        # cross entropy in both directions: once over rows (texts) and once over columns (images)
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0  # per-sample average of both losses, shape: (batch_size,)
        return loss.mean()

    def save_current(self):
        """
        Saves current model's parts to the save paths given in the config module:
        the poem/text encoder and image encoder only if configured as trainable,
        the text projection only if it was trainable, and the image projection always.
        """
        if self.is_image_poem_pair:
            if CFG.poem_encoder_trainable:
                self.encoder.model.save_pretrained(CFG.poem_encoder_save_path)
        else:
            if CFG.text_encoder_trainable:
                self.encoder.model.save_pretrained(CFG.text_encoder_save_path)
        if CFG.image_encoder_trainable:
            torch.save(self.image_encoder.model.state_dict(), CFG.image_encoder_weights_save_path)
        if self.text_projection_trainable:
            torch.save(self.text_projection.state_dict(), CFG.text_projection_save_path)
        torch.save(self.image_projection.state_dict(), CFG.image_projection_save_path)
def cross_entropy(preds, targets, reduction='none'):
    """
    Computes the cross entropy between logits and (soft) target distributions along the last dimension.
    Parameters:
    -----------
    preds: torch.Tensor
        logits; the log-softmax is taken over the last dimension
    targets: torch.Tensor
        target probabilities, same shape as (or broadcastable to) `preds`
    reduction: str, optional
        "none": return the per-row loss (shape of `preds` without its last dimension).
        "mean": return the mean of the per-row loss.
        "sum": return the sum of the per-row loss.
    Returns:
    --------
    loss tensor (per-row, mean or sum depending on `reduction`)
    Raises:
    -------
    ValueError
        if `reduction` is not one of "none", "mean", "sum".
    """
    # -sum_k targets_k * log_softmax(preds)_k, reduced over the class (last) dimension
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'")