# Prepare the CLIP models ahead of time to speed up loading them later
import torch
from torch import nn, Tensor
import os
from tqdm import tqdm
import json

from .utils import load


model_name_map = {
    "ViT-B/16": "vit_b_16",
    "ViT-L/14": "vit_l_14",
}


class CLIPTextEncoderTemp(nn.Module):
    """Temporary wrapper that holds only the text-encoder submodules of a full CLIP
    model, so that its state_dict() contains just the text-encoder weights."""

    def __init__(
        self,
        clip: nn.Module,
    ) -> None:
        super().__init__()
        self.context_length = clip.context_length
        self.vocab_size = clip.vocab_size
        self.dtype = clip.dtype
        self.token_embedding = clip.token_embedding
        self.positional_embedding = clip.positional_embedding
        self.transformer = clip.transformer
        self.ln_final = clip.ln_final
        self.text_projection = clip.text_projection

    def forward(self, text: Tensor) -> None:
        # Intentionally a no-op: this wrapper is only used to export weights,
        # never to run inference.
        pass


def prepare() -> None:
    """Load the original CLIP checkpoints and save the full model, the image encoder,
    and the text encoder as separate state dicts, together with JSON config files."""
    print("Preparing CLIP models...")
    curr_dir = os.path.dirname(os.path.abspath(__file__))
    weight_dir = os.path.join(curr_dir, "weights")
    config_dir = os.path.join(curr_dir, "configs")
    os.makedirs(weight_dir, exist_ok=True)
    os.makedirs(config_dir, exist_ok=True)
    device = torch.device("cpu")

    for model_name in tqdm(["ViT-B/16", "ViT-L/14"]):
        model = load(model_name, device=device).to(device)
        image_encoder = model.visual.to(device)
        text_encoder = CLIPTextEncoderTemp(model).to(device)
        # Save the full model and the two split encoders as separate state dicts.
        torch.save(model.state_dict(), os.path.join(weight_dir, f"clip_{model_name_map[model_name]}.pth"))
        torch.save(image_encoder.state_dict(), os.path.join(weight_dir, f"clip_image_encoder_{model_name_map[model_name]}.pth"))
        torch.save(text_encoder.state_dict(), os.path.join(weight_dir, f"clip_text_encoder_{model_name_map[model_name]}.pth"))
        model_config = {
            "embed_dim": model.embed_dim,
            # vision
            "image_resolution": model.image_resolution,
            "vision_layers": model.vision_layers,
            "vision_width": model.vision_width,
            "vision_patch_size": model.vision_patch_size,
            # text
            "context_length": model.context_length,
            "vocab_size": model.vocab_size,
            "transformer_width": model.transformer_width,
            "transformer_heads": model.transformer_heads,
            "transformer_layers": model.transformer_layers,
        }
        image_encoder_config = {
            "embed_dim": model.embed_dim,
            "image_resolution": model.image_resolution,
            "vision_layers": model.vision_layers,
            "vision_width": model.vision_width,
            "vision_patch_size": model.vision_patch_size,
            "vision_heads": model.vision_heads,
        }
        text_encoder_config = {
            "embed_dim": model.embed_dim,
            "context_length": model.context_length,
            "vocab_size": model.vocab_size,
            "transformer_width": model.transformer_width,
            "transformer_heads": model.transformer_heads,
            "transformer_layers": model.transformer_layers,
        }
        # Save the matching JSON configs alongside the weights.
        with open(os.path.join(config_dir, f"clip_{model_name_map[model_name]}.json"), "w") as f:
            json.dump(model_config, f, indent=4)
        with open(os.path.join(config_dir, f"clip_image_encoder_{model_name_map[model_name]}.json"), "w") as f:
            json.dump(image_encoder_config, f, indent=4)
        with open(os.path.join(config_dir, f"clip_text_encoder_{model_name_map[model_name]}.json"), "w") as f:
            json.dump(text_encoder_config, f, indent=4)
    print("Done!")