import os
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from starvector.model.image_encoder.clip_model import convert_weights_to_precision
from starvector.data.util import ImageTrainProcessor

class ImageEncoder(nn.Module):
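    """Pluggable visual backbone (CLIP ViT, VQGAN, ConvNeXt, or SigLIP) exposing a
    single forward() that returns patch-level features of shape
    (B, num_tokens, hidden_dim)."""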
    def __init__(self, config, **kwargs):
        super(ImageEncoder, self).__init__()
        
        image_size = config.image_size
        torch_dtype = kwargs.get('model_precision', config.torch_dtype)
        self.image_encoder_type = config.image_encoder_type
        if self.image_encoder_type == 'clip':
            self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size)
            convert_weights_to_precision(self, torch_dtype)
            self.processor = ImageTrainProcessor(size=config.image_size)

        elif self.image_encoder_type == 'vqgan':
            self.visual_encoder = self.build_vqgan_encoder()
            self.ln_vision = None
            self.processor = ImageTrainProcessor(size=config.image_size)

        elif self.image_encoder_type == 'convnext':
            self.visual_encoder = self.build_convnext_encoder()
            self.ln_vision = None
            self.processor = ImageTrainProcessor(size=config.image_size)

        elif 'siglip' in self.image_encoder_type:
            if self.image_encoder_type == 'siglip_512':
                model_name = "google/siglip-base-patch16-512"
            elif self.image_encoder_type == 'siglip_384':
                model_name = "google/siglip-large-patch16-384"
            elif self.image_encoder_type == 'siglip_256':
                model_name = "google/siglip-base-patch16-256"
            else:
                raise ValueError(f"Unknown SigLIP variant: {self.image_encoder_type}")
                
            from transformers import AutoProcessor, AutoModel

            self.visual_encoder = AutoModel.from_pretrained(
                model_name, torch_dtype=torch_dtype
            ).vision_model

            self.processor = AutoProcessor.from_pretrained(
                model_name, torch_dtype=torch_dtype
            )

    def build_clip_encoder(self, image_size):
        from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm
        visual_encoder = VisionTransformer(
            input_resolution=image_size,
            patch_size=14,
            width=1024,
            layers=23,
            heads=16,
            use_grad_checkpointing=False)

        ln_vision = LayerNorm(visual_encoder.num_features)
        return visual_encoder, ln_vision
    
    def build_vqgan_encoder(self):
        from taming.modules.diffusionmodules.model import Encoder
        VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint"  # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md
        vqgan_chkp_path = VQGAN_CHECKPOINT
        files_in_directory = os.listdir(vqgan_chkp_path + '/configs')
        vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0]
        vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file))
        visual_encoder = Encoder(**vqgan_config.model.params.ddconfig)

        # Load checkpoint weights
        checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict']

        # Create a new state_dict with modified keys
        new_state_dict = {}
        for key, value in checkpoint.items():
            if key.startswith('encoder.'):
                new_key = key[len('encoder.'):]
                new_state_dict[new_key] = value

        # Load weights
        visual_encoder.load_state_dict(new_state_dict)
        return visual_encoder
    
    def build_convnext_encoder(self):
        import open_clip
        model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k')
        return model.visual

    def forward(self, image):
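        # All branches return features shaped (B, num_tokens, hidden_dim); the
        # spatial backbones (VQGAN, ConvNeXt) are flattened and permuted to match.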
        if self.image_encoder_type == 'clip':
            embeds = self.visual_encoder(image)
            out = self.ln_vision(embeds)
        elif self.image_encoder_type == 'open-clip':
            out = self.visual_encoder(image)[1]
            out = self.ln_vision(out)
        elif self.image_encoder_type == 'vqgan':
            out = self.visual_encoder(image)
            size = out.size()
            out = out.view(size[0], size[1], -1)
            out = out.permute(0, 2, 1)
        elif self.image_encoder_type == 'convnext':
            out = self.visual_encoder.trunk.forward_features(image)
            size = out.size()
            out = out.view(size[0], size[1], -1)
            out = out.permute(0, 2, 1)
        elif 'siglip' in self.image_encoder_type:
            out = self.visual_encoder(image)["last_hidden_state"]
        return out

    def process_images(self, images):
        if self.image_encoder_type == 'clip':
            res = []
            for image in images:
                res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W
            return res
        else:
            return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0)
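
# --- Usage sketch (illustrative, not part of the original module) ---
# Minimal example of constructing the encoder from an OmegaConf config and
# running a dummy batch through it. The concrete values below (image_size=224,
# the 'clip' backbone, float32 precision) are assumptions; substitute whatever
# your training config actually specifies.
if __name__ == "__main__":
    cfg = OmegaConf.create({
        "image_size": 224,             # divisible by the ViT patch size (14)
        "torch_dtype": "float32",      # fallback read by __init__ when model_precision is absent
        "image_encoder_type": "clip",  # one of: clip, vqgan, convnext, siglip_256/384/512
    })
    encoder = ImageEncoder(cfg, model_precision=torch.float32)

    dummy = torch.randn(2, 3, cfg.image_size, cfg.image_size)  # B, 3, H, W
    features = encoder(dummy)
    print(features.shape)  # expected (B, num_tokens, hidden_dim)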