import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy
from typing import Tuple

from clip.model import VisionTransformer


class CSD(nn.Module):
    """CLIP vision transformer backbone (ViT-L/14 by default) with separate
    style and content projection heads."""

    def __init__(
        self,
        vit_input_resolution: int = 224,
        vit_patch_size: int = 14,
        vit_width: int = 1024,
        vit_layers: int = 24,
        vit_heads: int = 16,
        vit_output_dim: int = 768,
    ) -> None:
        super().__init__()

        self.backbone = VisionTransformer(
            input_resolution=vit_input_resolution,
            patch_size=vit_patch_size,
            width=vit_width,
            layers=vit_layers,
            heads=vit_heads,
            output_dim=vit_output_dim,
        )

        # Clone CLIP's output projection into two independent heads (one for
        # style, one for content), then drop the backbone's own projection so
        # it returns raw, pre-projection features.
        self.last_layer_style = deepcopy(self.backbone.proj)
        self.last_layer_content = deepcopy(self.backbone.proj)
        self.backbone.proj = None

    def forward(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        features = self.backbone(pixel_values)

        # Project the shared backbone features into the style space and
        # L2-normalize each embedding.
        style_output = features @ self.last_layer_style
        style_output = F.normalize(style_output, dim=1, p=2)

        # Same projection and normalization for the content space.
        content_output = features @ self.last_layer_content
        content_output = F.normalize(content_output, dim=1, p=2)

        return features, style_output, content_output
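

# A minimal usage sketch (assumption: the `clip` package from
# https://github.com/openai/CLIP is installed, so `VisionTransformer` matches
# CLIP's ViT-L/14 vision tower). Weights here are randomly initialized, so
# the outputs are illustrative only.
if __name__ == "__main__":
    model = CSD().eval()
    pixel_values = torch.randn(2, 3, 224, 224)  # batch of 2 preprocessed RGB images
    with torch.no_grad():
        features, style_output, content_output = model(pixel_values)
    # Backbone features are pre-projection (width 1024); both heads project
    # into the 768-dim CLIP embedding space.
    print(features.shape)        # torch.Size([2, 1024])
    print(style_output.shape)    # torch.Size([2, 768])
    print(content_output.shape)  # torch.Size([2, 768])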