Upload model
- config.json +17 -0
- config.py +22 -0
- csd.py +44 -0
- model.py +36 -0
- model.safetensors +3 -0
config.json
ADDED
@@ -0,0 +1,17 @@
{
  "architectures": [
    "CSDModel"
  ],
  "auto_map": {
    "AutoConfig": "config.CSDConfig",
    "AutoModel": "model.CSDModel"
  },
  "torch_dtype": "float32",
  "transformers_version": "4.37.2",
  "vit_heads": 16,
  "vit_input_resolution": 224,
  "vit_layers": 24,
  "vit_output_dim": 768,
  "vit_patch_size": 14,
  "vit_width": 1024
}
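The auto_map block above is what ties the standard Auto classes to the custom code shipped in this repo (config.CSDConfig and model.CSDModel), so loading requires trust_remote_code=True. A minimal loading sketch; the repo id below is a placeholder for wherever this upload lives:

from transformers import AutoConfig, AutoModel

# "your-namespace/csd" is a placeholder repo id, not part of this upload.
# trust_remote_code=True lets Transformers execute config.py and model.py
# to resolve the auto_map entries shown above.
config = AutoConfig.from_pretrained("your-namespace/csd", trust_remote_code=True)
model = AutoModel.from_pretrained("your-namespace/csd", trust_remote_code=True)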
config.py
ADDED
@@ -0,0 +1,22 @@
from transformers import PretrainedConfig


class CSDConfig(PretrainedConfig):
    def __init__(
        self,
        vit_input_resolution: int = 224,
        vit_patch_size: int = 14,
        vit_width: int = 1024,
        vit_layers: int = 24,
        vit_heads: int = 16,
        vit_output_dim: int = 768,
        **kwargs
    ) -> None:
        super(CSDConfig, self).__init__(**kwargs)

        self.vit_input_resolution = vit_input_resolution
        self.vit_patch_size = vit_patch_size
        self.vit_width = vit_width
        self.vit_layers = vit_layers
        self.vit_heads = vit_heads
        self.vit_output_dim = vit_output_dim
csd.py
ADDED
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy
from clip.model import VisionTransformer
from typing import Tuple


class CSD(nn.Module):
    def __init__(
        self,
        vit_input_resolution: int = 224,
        vit_patch_size: int = 14,
        vit_width: int = 1024,
        vit_layers: int = 24,
        vit_heads: int = 16,
        vit_output_dim: int = 768,
    ) -> None:
        super(CSD, self).__init__()

        # CLIP vision transformer used as the shared backbone.
        self.backbone = VisionTransformer(
            input_resolution=vit_input_resolution,
            patch_size=vit_patch_size,
            width=vit_width,
            layers=vit_layers,
            heads=vit_heads,
            output_dim=vit_output_dim,
        )

        # Replace CLIP's single projection with two copies, one per head
        # (style and content); the backbone then returns unprojected features.
        self.last_layer_style = deepcopy(self.backbone.proj)
        self.last_layer_content = deepcopy(self.backbone.proj)
        self.backbone.proj = None

    def forward(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        features = self.backbone(pixel_values)

        # L2-normalized style embedding.
        style_output = features @ self.last_layer_style
        style_output = F.normalize(style_output, dim=1, p=2)

        # L2-normalized content embedding.
        content_output = features @ self.last_layer_content
        content_output = F.normalize(content_output, dim=1, p=2)

        return features, content_output, style_output
model.py
ADDED
@@ -0,0 +1,36 @@
import torch

from typing import Tuple
from dataclasses import dataclass
from transformers import PretrainedConfig, PreTrainedModel

from .csd import CSD
from .config import CSDConfig


@dataclass
class CSDOutput:
    image_embeds: torch.Tensor
    style_embeds: torch.Tensor
    content_embeds: torch.Tensor


class CSDModel(PreTrainedModel):
    config_class = CSDConfig

    def __init__(self, config: CSDConfig) -> None:
        super(CSDModel, self).__init__(config)

        self.model = CSD(
            vit_input_resolution=config.vit_input_resolution,
            vit_patch_size=config.vit_patch_size,
            vit_width=config.vit_width,
            vit_layers=config.vit_layers,
            vit_heads=config.vit_heads,
            vit_output_dim=config.vit_output_dim,
        )

    @torch.inference_mode()
    def forward(self, pixel_values: torch.Tensor) -> CSDOutput:
        # CSD.forward returns (features, content, style) in that order.
        image_embeds, content_embeds, style_embeds = self.model(pixel_values)
        return CSDOutput(image_embeds=image_embeds, style_embeds=style_embeds, content_embeds=content_embeds)
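CSDModel.forward wraps the three CSD outputs in a CSDOutput. A minimal end-to-end inference sketch, assuming the OpenAI clip package is installed (csd.py imports clip.model.VisionTransformer), CLIP ViT-L/14-style preprocessing (no processor config ships with this upload), and a placeholder repo id:

from PIL import Image
from transformers import AutoModel, CLIPImageProcessor

# Placeholder repo id; the processor choice is an assumption made here,
# matching vit_input_resolution=224 and vit_patch_size=14 from config.json.
model = AutoModel.from_pretrained("your-namespace/csd", trust_remote_code=True).eval()
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

image = Image.open("example.jpg").convert("RGB")
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]

output = model(pixel_values)
print(output.image_embeds.shape)    # expected (1, 1024): unprojected backbone features (vit_width)
print(output.style_embeds.shape)    # expected (1, 768): style projection (vit_output_dim)
print(output.content_embeds.shape)  # expected (1, 768): content projection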
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a4edeb72ee261d99700b654ec40d89484ed3ff02c49a277a63668897a9261914
size 1219048024