from copy import deepcopy
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from clip.model import VisionTransformer


class CSD(nn.Module):
    """CLIP vision transformer backbone with separate style and content heads.

    The backbone's own output projection is copied twice (once for style,
    once for content) and then removed from the backbone, so the backbone
    yields un-projected features and each head applies its own projection.
    Both heads start as identical copies of the backbone projection.
    """

    def __init__(
        self,
        vit_input_resolution: int = 224,
        vit_patch_size: int = 14,
        vit_width: int = 1024,
        vit_layers: int = 24,
        vit_heads: int = 16,
        vit_output_dim: int = 768,
    ) -> None:
        """Build the backbone and the two projection heads.

        Args:
            vit_input_resolution: Forwarded to ``VisionTransformer`` as
                ``input_resolution``.
            vit_patch_size: Forwarded as ``patch_size``.
            vit_width: Forwarded as ``width``.
            vit_layers: Forwarded as ``layers``.
                NOTE(review): the default used to be 768, which looks like a
                copy of ``vit_output_dim``; 24 is the layer count that goes
                with the other ViT-L/14 defaults — confirm against the
                checkpoint this model is loaded from.
            vit_heads: Forwarded as ``heads``.
            vit_output_dim: Forwarded as ``output_dim``.
        """
        super().__init__()

        self.backbone = VisionTransformer(
            input_resolution=vit_input_resolution,
            patch_size=vit_patch_size,
            width=vit_width,
            layers=vit_layers,
            heads=vit_heads,
            output_dim=vit_output_dim,
        )

        # Take independent copies of the backbone projection *before*
        # clearing it; each copy becomes a head owned by this module.
        self.last_layer_style = deepcopy(self.backbone.proj)
        self.last_layer_content = deepcopy(self.backbone.proj)

        # With ``proj`` set to None the backbone is expected to skip its
        # final projection and return raw features — TODO confirm against
        # the installed ``clip.model.VisionTransformer.forward``.
        self.backbone.proj = None

    def forward(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute backbone features and the normalized style/content outputs.

        Args:
            pixel_values: Image batch passed straight to the backbone.
                Presumably ``(batch, 3, H, W)`` — verify against the caller.

        Returns:
            A 3-tuple ``(features, style_output, content_output)``:
            the backbone output as-is, then that output multiplied by the
            style and content projections respectively, each L2-normalized
            along dimension 1.
        """
        features = self.backbone(pixel_values)

        # dim=1 assumes features are 2-D (batch, width) — TODO confirm.
        style_output = F.normalize(features @ self.last_layer_style, dim=1, p=2)
        content_output = F.normalize(
            features @ self.last_layer_content, dim=1, p=2
        )

        return features, style_output, content_output