File size: 2,647 Bytes
73bde56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""

image_proj_model.py



This module defines the ImageProjModel class, which is responsible for

projecting image embeddings into a different dimensional space. The model 

leverages a linear transformation followed by a layer normalization to 

reshape and normalize the input image embeddings for further processing in 

cross-attention mechanisms or other downstream tasks.



Classes:

    ImageProjModel



Dependencies:

    torch

    diffusers.ModelMixin



"""

import torch
from diffusers import ModelMixin


class ImageProjModel(ModelMixin):
    """Project CLIP image embeddings into cross-attention context tokens.

    A single linear layer expands each input embedding into
    ``clip_extra_context_tokens`` tokens of size ``cross_attention_dim``;
    a LayerNorm then normalizes each token over its last dimension. The
    resulting tokens are intended for cross-attention mechanisms or other
    downstream tasks.

    Attributes:
        cross_attention_dim (int): Dimension of each projected context token.
        clip_extra_context_tokens (int): Number of context tokens produced
            per input embedding.
        proj (torch.nn.Linear): Maps ``clip_embeddings_dim`` to
            ``clip_extra_context_tokens * cross_attention_dim``.
        norm (torch.nn.LayerNorm): Normalizes over ``cross_attention_dim``.
        generator: Unused placeholder, always ``None`` here; kept for
            compatibility with callers that may set it externally.
    """

    def __init__(
        self,
        cross_attention_dim: int = 1024,
        clip_embeddings_dim: int = 1024,
        clip_extra_context_tokens: int = 4,
    ):
        """Initialize the projection and normalization layers.

        Args:
            cross_attention_dim (int): Dimension of each output token.
            clip_embeddings_dim (int): Dimension of the input CLIP embedding.
            clip_extra_context_tokens (int): Number of tokens to produce
                per input embedding.
        """
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # One linear layer emits all tokens at once; forward() separates
        # them with a reshape.
        self.proj = torch.nn.Linear(
            clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Project image embeddings to normalized context tokens.

        Args:
            image_embeds (torch.Tensor): Input image embeddings whose last
                dimension is ``clip_embeddings_dim`` (typically shaped
                ``batch_size x clip_embeddings_dim`` — confirm against the
                caller if a token axis is present, since the reshape below
                folds all leading dimensions into the batch axis).

        Returns:
            torch.Tensor: Context tokens of shape
            ``batch_size x clip_extra_context_tokens x cross_attention_dim``.
        """
        embeds = image_embeds
        # reshape(-1, tokens, dim) splits the flat projection into
        # per-token vectors; the result is 3-D, not a flattened 2-D tensor.
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens