File size: 18,565 Bytes
1bc9b9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
import torch
from torch import nn
import torch.nn.functional as F

#FIX
import config as CFG
from modules import TextEncoder, ProjectionHead, ImageEncoder


class PoemTextModel(nn.Module):
    """
    Two-tower model producing poem and text embeddings and scoring how similar they are.
    ...
    Attributes:
    -----------
        poem_encoder : TextEncoder
            encoder that turns tokenized poems into poem features
        text_encoder : TextEncoder
            encoder that turns tokenized texts into text features
        poem_projection: ProjectionHead
            head applied to the poem encoder output (maps it to the shared embedding space)
        text_projection: ProjectionHead
            head applied to the text encoder output (maps it to the shared embedding space)
        temperature: float
            scaling factor applied to the dot similarities in the loss
    
    Methods:
    --------
        forward(batch):
            returns poem and text embeddings of batch
        similarity_scores(batch):
            computes cosine (normalized dot) similarities of a batch of text-poem pairs
        predict(batch):
            predicts the most similar poem idx for each text (using previous methods)
        calculate_loss(poem_embeddings, text_embeddings):
            computes contrastive (cross entropy) loss for both poems and texts.
        save_current():
            saves current model's encoders (if trainable) and projection heads.
    """
    def __init__(
        self,
        poem_encoder_pretrained,
        text_encoder_pretrained,
        temperature=CFG.temperature,
        poem_embedding=CFG.poem_embedding,
        text_embedding=CFG.text_embedding,
    ):
        """
        Builds the two encoders and their projection heads
            Parameters:
            -----------
                poem_encoder_pretrained: bool
                    forwarded to the poem TextEncoder as its `pretrained` flag.
                text_encoder_pretrained: bool
                    forwarded to the text TextEncoder as its `pretrained` flag.
                temperature: float, optional
                    scaling factor applied to the dot similarities
                poem_embedding: int, optional
                    dim of the poem encoder output that is fed to the poem projection
                text_embedding: int, optional
                    dim of the text encoder output that is fed to the text projection
        """
        super().__init__()
        # Encoders: architecture, checkpoint name and trainability all come from the config
        self.poem_encoder = TextEncoder(
            CFG.poem_encoder_model,
            CFG.poem_encoder_pretrained_name,
            pretrained=poem_encoder_pretrained,
            trainable=CFG.poem_encoder_trainable,
        )
        self.text_encoder = TextEncoder(
            CFG.text_encoder_model,
            CFG.text_encoder_pretrained_name,
            pretrained=text_encoder_pretrained,
            trainable=CFG.text_encoder_trainable,
        )

        # Poem head; weights are restored only when a load path is configured
        self.poem_projection = ProjectionHead(embedding_dim=poem_embedding)
        poem_weights_path = CFG.poem_projection_load_path
        if poem_weights_path:
            poem_state = torch.load(poem_weights_path, map_location=CFG.device)
            self.poem_projection.load_state_dict(poem_state)

        # Text head; same optional restore as above
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        text_weights_path = CFG.text_projection_load_path
        if text_weights_path:
            text_state = torch.load(text_weights_path, map_location=CFG.device)
            self.text_projection.load_state_dict(text_state)

        self.temperature = temperature

    def forward(self, batch):
        """
        Embeds the poems and texts of a batch

            Parameters:
            -----------
            batch: dict-like
                must provide 'beyt' and 'text' entries, each holding the tokenizer output
                fields 'input_ids' and 'attention_mask'

            Returns:
            --------
            tuple (poem_embeddings, text_embeddings)
            (each of shape (batch_size, projection_dim))
        """
        poem_tokens, text_tokens = batch["beyt"], batch["text"]

        # Encoder features first (poems, then texts) ...
        encoded_poems = self.poem_encoder(
            input_ids=poem_tokens["input_ids"],
            attention_mask=poem_tokens["attention_mask"],
        )
        encoded_texts = self.text_encoder(
            input_ids=text_tokens["input_ids"],
            attention_mask=text_tokens["attention_mask"],
        )

        # ... then both are projected so that they share one dimension
        projected_poems = self.poem_projection(encoded_poems)
        projected_texts = self.text_projection(encoded_texts)
        return projected_poems, projected_texts

    def similarity_scores(self, batch):
        """
        Scores every text of the batch against every poem of the batch

            Parameters:
            -----------
            batch: dict-like
                same input as forward (entries 'beyt' and 'text')

            Returns:
            --------
            similarity matrix of shape (batch_size, batch_size):
            rows index texts, columns index poems
        """
        poem_vectors, text_vectors = self.forward(batch)
        # With L2-normalized rows the dot product below is the cosine similarity
        unit_poems = F.normalize(poem_vectors, p=2, dim=-1)
        unit_texts = F.normalize(text_vectors, p=2, dim=-1)
        return torch.matmul(unit_texts, unit_poems.T)

    def predict(self, batch):
        """
        Picks, for each text, the poem of the batch with the highest similarity

            Parameters:
            -----------
            batch: dict-like
                same input as forward (entries 'beyt' and 'text')

            Returns:
            --------
            index of poem predicted for each text (of shape (batch_size))
        """
        scores = self.similarity_scores(batch)
        # dim=1 runs over the poems, so the argmax is taken per text (per row)
        return scores.argmax(dim=1)

    def calculate_loss(self, poem_embeddings, text_embeddings):
        """
        Symmetric contrastive (soft-target cross entropy) loss over poems and texts.

            Parameters:
            -----------
            poem_embeddings: of shape (batch_size, projection_dim)
                output embeddings of poem projection head
            text_embeddings: of shape (batch_size, projection_dim)
                output embeddings of text projection head

            Returns:
            --------
            mean over the batch of the averaged text-side and poem-side losses
        """
        temperature = self.temperature
        # Logits: text-vs-poem dot products divided by the temperature
        logits = torch.matmul(text_embeddings, poem_embeddings.T) / temperature

        # Soft targets: the poem-poem and text-text similarity matrices are averaged,
        # multiplied by the temperature and turned into a distribution with a softmax
        poem_self_similarity = torch.matmul(poem_embeddings, poem_embeddings.T)
        text_self_similarity = torch.matmul(text_embeddings, text_embeddings.T)
        soft_targets = F.softmax(
            (poem_self_similarity + text_self_similarity) / 2 * temperature, dim=-1
        )

        # Cross entropy along both axes: rows are texts, columns (after transposing) are poems
        per_text_loss = cross_entropy(logits, soft_targets, reduction='none')
        per_poem_loss = cross_entropy(logits.T, soft_targets.T, reduction='none')
        per_pair_loss = (per_poem_loss + per_text_loss) / 2.0  # shape: (batch_size)
        return per_pair_loss.mean()

    def save_current(self):
        """
        Writes the encoders (only those marked trainable in the config) and both projection heads
        to the save paths given in the config.
        """
        # Encoders are only persisted when the config marks them as trainable
        if CFG.text_encoder_trainable:
            text_backbone = self.text_encoder.model
            text_backbone.save_pretrained(CFG.text_encoder_save_path)
        if CFG.poem_encoder_trainable:
            poem_backbone = self.poem_encoder.model
            poem_backbone.save_pretrained(CFG.poem_encoder_save_path)

        # Projection heads are always written
        text_head_state = self.text_projection.state_dict()
        torch.save(text_head_state, CFG.text_projection_save_path)
        poem_head_state = self.poem_projection.state_dict()
        torch.save(poem_head_state, CFG.poem_projection_save_path)

class CLIPModel(nn.Module):
    """
    Model predicting poem/text and image embeddings, and their similarities.
    ...
    Attributes:
    -----------
        encoder : TextEncoder
            encoder used for extracting poem/text embeddings
        image_encoder : ImageEncoder
            encoder used for extracting image embeddings
        text_projection: ProjectionHead
            projection head used for poem/text embeddings (projects text encoder output to shared embedding space)
        image_projection: ProjectionHead
            projection head used for image embeddings (projects image encoder output to shared embedding space)
        text_projection_trainable: bool
            whether the text projection is trained (and therefore saved by save_current)
        is_image_poem_pair: bool
            whether the text side of the model handles poems (True) or plain texts (False)
        temperature: float
            used to scale the dot similarities
    
    Methods:
    --------
        forward(batch):
            returns image and poem/text embeddings of batch
        similarity_scores(batch):
            computes dot similarities of a batch of image-text pair
        predict(batch):
            predicts the most similar poem/text idx for each image (using previous methods)
        calculate_loss(image_embeddings, text_embeddings):
            computes contrastive (cross entropy) loss for both poems/texts and images.
        save_current():
            saves current model's encoders and projection heads (if trainable).
    """
    def __init__(
        self,
        image_encoder_pretrained,
        text_encoder_pretrained,
        text_projection_trainable,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
        is_image_poem_pair=True
    ):
        """
        Initializes model's submodules
            Parameters:
            -----------
                image_encoder_pretrained: bool
                    whether or not to load a pretrained image encoder.
                text_encoder_pretrained: bool
                    whether or not to load a pretrained text encoder.
                text_projection_trainable: bool
                    whether or not to train text projection 
                    (since the text projection is frozen in our trainings unlike other projections of models)
                temperature: float, optional
                    used to scale the dot similarities
                image_embedding: int, optional
                    dim of image encoder's encoding output before projection
                text_embedding: int, optional
                    dim of text encoder's encoding output before projection
                is_image_poem_pair: bool, optional
                    if True, the text inputs to this model is poems and needs one of the poem encoders to predict embeddings with.
                    else it's a text that needs the encoders dedicated to text.
        """
        super().__init__()
        # Loading the encoders and their projections using configs
        self.image_encoder = ImageEncoder(pretrained=image_encoder_pretrained, trainable=CFG.image_encoder_trainable)

        # The text side is configured from the poem settings or the text settings;
        # everything else about building it is identical, so only the config values differ.
        if is_image_poem_pair:
            encoder_model = CFG.poem_encoder_model
            encoder_pretrained_name = CFG.poem_encoder_pretrained_name
            encoder_trainable = CFG.poem_encoder_trainable
            text_projection_load_path = CFG.poem_projection_load_path
        else:
            encoder_model = CFG.text_encoder_model
            encoder_pretrained_name = CFG.text_encoder_pretrained_name
            encoder_trainable = CFG.text_encoder_trainable
            text_projection_load_path = CFG.text_projection_load_path

        self.encoder = TextEncoder(encoder_model, encoder_pretrained_name, pretrained=text_encoder_pretrained, trainable=encoder_trainable)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        if text_projection_load_path: # if provided, load projection weights from this path
            self.text_projection.load_state_dict(torch.load(text_projection_load_path, map_location=CFG.device))

        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        if CFG.image_projection_load_path: # if provided, load projection weights from this path
            self.image_projection.load_state_dict(torch.load(CFG.image_projection_load_path, map_location=CFG.device))

        # Freezing the text projection when it should not be trained
        if not text_projection_trainable:
            for p in self.text_projection.parameters():
                p.requires_grad = False

        self.text_projection_trainable = text_projection_trainable
        self.is_image_poem_pair = is_image_poem_pair
        self.temperature = temperature

    def forward(self, batch):
        """
        returns image and text/poem embeddings of batch

            Parameters:
            -----------
            batch: dict-like
                input (containing image-text/poem pairs (text/poem encoded using the encoder's tokenizer) 
                with keys 'image' and 'text')

            Returns:
            --------
            image and poem/text embeddings of batch, in that order (each of shape (batch_size, projection_dim))
        """
        images, texts = batch["image"], batch["text"]
        # Getting Image and Text Features
        image_features = self.image_encoder(images)
        text_features = self.encoder(
            input_ids=texts["input_ids"], attention_mask=texts["attention_mask"]
        )
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)
        
        return image_embeddings, text_embeddings
    
    def similarity_scores(self, batch):
        """
        computes dot similarities of a batch of text/poem-image pair

            Parameters:
            -----------
            batch: dict-like
                input (containing image-text/poem pairs (text/poem encoded using the encoder's tokenizer) 
                with keys 'image' and 'text')

            Returns:
            --------
            dot similarity of image and poem/text embeddings of batch (of shape (batch_size, batch_size))
        """
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings, text_embeddings = self.forward(batch)
        # Normalizing embeddings
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        # Computing dot / cosine similarity of the normalized embeddings
        dot_similarity = image_embeddings_n @ text_embeddings_n.T
        return dot_similarity # (batch_size, batch_size) first dim is images, second dim is poems/texts for each image

    def predict(self, batch):
        """
        predicts the most similar poem/text (idx) for each image (using previous methods)

            Parameters:
            -----------
            batch: dict-like
                input (containing image-text/poem pairs (text/poem encoded using the encoder's tokenizer) 
                with keys 'image' and 'text')

            Returns:
            --------
            index of poem/text predicted for each image (of shape (batch_size))
        """
        dot_similarity = self.similarity_scores(batch)
        # Getting argmax over the second dimension (poems/texts) of the dot-similarities
        # to predict index of the most similar poem/text for each image
        return torch.argmax(dot_similarity, dim=1)

    def calculate_loss(self, image_embeddings, text_embeddings):
        """
        computes contrastive (cross entropy) loss for both poems/texts and images.

            Parameters:
            -----------
            image_embeddings: of shape (batch_size, projection_dim)
                output embeddings of image projection head
            text_embeddings: of shape (batch_size, projection_dim)
                output embeddings of text projection head

            Returns:
            --------
            average of the loss computed from inputs
        """
        # dot similarity of the embeddings scaled by temperature (logits)
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        # computing targets for the cross entropy loss to compare with logits.
        # each embedding's similarity is computed with itself and then averaged, 
        # scaled by the temperature parameter, and normalized into a probability distribution via a softmax
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        # taking cross entropy loss in both dimensions: once for texts and once for images
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss =  (images_loss + texts_loss) / 2.0  # average of losses. shape: (batch_size)
        return loss.mean()
    
    def save_current(self):
        """
        saves current model's encoders and projection heads (if trainable).
        """
        # The text-side encoder is a poem encoder or a text encoder depending on the mode,
        # so its trainable flag and save path are taken from the matching config entries.
        if self.is_image_poem_pair:
            if CFG.poem_encoder_trainable:
                self.encoder.model.save_pretrained(CFG.poem_encoder_save_path)
        else:
            if CFG.text_encoder_trainable:
                self.encoder.model.save_pretrained(CFG.text_encoder_save_path)
        if CFG.image_encoder_trainable:
            torch.save(self.image_encoder.model.state_dict(), CFG.image_encoder_weights_save_path)
        # NOTE(review): in image-poem mode the text projection is loaded from
        # CFG.poem_projection_load_path but is saved to CFG.text_projection_save_path - confirm this is intended.
        if self.text_projection_trainable:
            torch.save(self.text_projection.state_dict(), CFG.text_projection_save_path)
        # The image projection is always saved
        torch.save(self.image_projection.state_dict(), CFG.image_projection_save_path)

def cross_entropy(preds, targets, reduction='none'):
    """
    Computes the soft-target cross entropy of logits and targets using their last dimension

        Parameters:
        -----------
            preds: tensor
                logits; a log-softmax is applied over the last dimension
            targets: tensor
                target weights/probabilities, multiplied elementwise with the log-probabilities
            reduction: str, optional
                if set to "mean", return loss mean across all dimensions.
                if set to "none", return loss computed using last dim.
        
        Returns:
        --------
            loss (one value per position of the leading dimensions) or loss average (scalar)

        Raises:
        -------
            ValueError
                if reduction is neither "none" nor "mean"
    """
    # Fail loudly on an unsupported reduction instead of silently returning None
    if reduction not in ("none", "mean"):
        raise ValueError(
            f"reduction must be 'none' or 'mean', got {reduction!r}"
        )
    log_probs = F.log_softmax(preds, dim=-1)
    # Sum over the same (last) dimension the softmax was taken over, so that inputs
    # with more than two dimensions are reduced along the class axis as well
    loss = (-targets * log_probs).sum(dim=-1) # cross entropy loss
    if reduction == "mean":
        return loss.mean()
    return loss