vvmatorin committed on
Commit
520a6ec
·
verified ·
1 Parent(s): 7b765c3

Upload model

Browse files
Files changed (5) hide show
  1. config.json +17 -0
  2. config.py +22 -0
  3. csd.py +44 -0
  4. model.py +36 -0
  5. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CSDModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "config.CSDConfig",
7
+ "AutoModel": "model.CSDModel"
8
+ },
9
+ "torch_dtype": "float32",
10
+ "transformers_version": "4.37.2",
11
+ "vit_heads": 16,
12
+ "vit_input_resolution": 224,
13
+ "vit_layers": 24,
14
+ "vit_output_dim": 768,
15
+ "vit_patch_size": 14,
16
+ "vit_width": 1024
17
+ }
config.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class CSDConfig(PretrainedConfig):
    """Configuration for the CSD model.

    Holds the hyperparameters of the underlying vision-transformer
    backbone; defaults mirror the shipped ``config.json``.
    """

    def __init__(
        self,
        vit_input_resolution: int = 224,
        vit_patch_size: int = 14,
        vit_width: int = 1024,
        vit_layers: int = 24,
        vit_heads: int = 16,
        vit_output_dim: int = 768,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # ViT backbone geometry.
        self.vit_input_resolution = vit_input_resolution
        self.vit_patch_size = vit_patch_size
        self.vit_width = vit_width
        self.vit_layers = vit_layers
        self.vit_heads = vit_heads
        self.vit_output_dim = vit_output_dim
csd.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from copy import deepcopy
6
+ from clip.model import VisionTransformer
7
+ from typing import Tuple
8
+
9
+
10
class CSD(nn.Module):
    """CLIP-style vision transformer with separate style and content heads.

    The backbone's single projection matrix is duplicated into two
    independent projections (style and content), and the backbone's own
    projection is removed so ``forward`` can project the raw features twice.
    """

    def __init__(
        self,
        vit_input_resolution: int = 224,
        vit_patch_size: int = 14,
        vit_width: int = 1024,
        # BUGFIX: the default was 768 — that is an output-dim value, not a
        # transformer depth. 24 matches both config.py and config.json.
        vit_layers: int = 24,
        vit_heads: int = 16,
        vit_output_dim: int = 768,
    ) -> None:
        """Build the ViT backbone and the two projection heads.

        Args mirror ``CSDConfig``: input resolution, patch size, hidden
        width, number of transformer layers, attention heads, and the
        projection output dimension.
        """
        super().__init__()

        self.backbone = VisionTransformer(
            input_resolution=vit_input_resolution,
            patch_size=vit_patch_size,
            width=vit_width,
            layers=vit_layers,
            heads=vit_heads,
            output_dim=vit_output_dim,
        )

        # Clone the backbone's projection into independent style/content
        # heads, then disable the backbone's projection so it returns
        # unprojected features.
        self.last_layer_style = deepcopy(self.backbone.proj)
        self.last_layer_content = deepcopy(self.backbone.proj)
        self.backbone.proj = None

    def forward(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode images and project into content and style spaces.

        Returns:
            ``(features, content_output, style_output)`` — the raw backbone
            features plus the two L2-normalized projections.
        """
        features = self.backbone(pixel_values)

        style_output = features @ self.last_layer_style
        style_output = F.normalize(style_output, dim=1, p=2)

        content_output = features @ self.last_layer_content
        content_output = F.normalize(content_output, dim=1, p=2)

        return features, content_output, style_output
model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from typing import Tuple
4
+ from dataclasses import dataclass
5
+ from transformers import PretrainedConfig, PreTrainedModel
6
+
7
+ from .csd import CSD
8
+ from .config import CSDConfig
9
+
10
+
11
@dataclass
class CSDOutput:
    """Container returned by ``CSDModel.forward``.

    Bundles the backbone features together with the two projected
    embeddings produced by the style and content heads.
    """

    # Raw backbone image features.
    image_embeds: torch.Tensor
    # Style-head embedding.
    style_embeds: torch.Tensor
    # Content-head embedding.
    content_embeds: torch.Tensor
16
+
17
+
18
class CSDModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the CSD backbone."""

    config_class = CSDConfig

    def __init__(self, config: CSDConfig) -> None:
        """Instantiate the CSD backbone from a ``CSDConfig``."""
        super().__init__(config)

        self.model = CSD(
            vit_input_resolution=config.vit_input_resolution,
            vit_patch_size=config.vit_patch_size,
            vit_width=config.vit_width,
            vit_layers=config.vit_layers,
            vit_heads=config.vit_heads,
            vit_output_dim=config.vit_output_dim,
        )

    @torch.inference_mode()
    def forward(self, pixel_values: torch.Tensor) -> CSDOutput:
        """Embed a batch of images; runs under inference mode (no grads).

        Returns a ``CSDOutput`` with the backbone features and the
        normalized style and content embeddings.
        """
        # BUGFIX: CSD.forward returns (features, content_output,
        # style_output); the previous unpacking read it as
        # (features, style, content), swapping the two embeddings.
        image_embeds, content_embeds, style_embeds = self.model(pixel_values)
        return CSDOutput(
            image_embeds=image_embeds,
            style_embeds=style_embeds,
            content_embeds=content_embeds,
        )
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4edeb72ee261d99700b654ec40d89484ed3ff02c49a277a63668897a9261914
3
+ size 1219048024