CSD / csd.py
vvmatorin's picture
fix: csd.py output order to match CSDOutput
881be7a verified
raw
history blame
1.32 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from clip.model import VisionTransformer
from typing import Tuple
class CSD(nn.Module):
    """Vision-transformer backbone with separate style and content heads.

    A ``VisionTransformer`` is built from the constructor arguments. Its
    output projection (``backbone.proj``) is deep-copied twice, once per head,
    so the two heads are independent tensors, and the backbone's own
    projection is then set to ``None``.
    """

    def __init__(
        self,
        vit_input_resolution: int = 224,
        vit_patch_size: int = 14,
        vit_width: int = 1024,
        vit_layers: int = 24,
        vit_heads: int = 16,
        vit_output_dim: int = 768,
    ) -> None:
        """Build the backbone and the two projection heads.

        Args:
            vit_input_resolution: Forwarded as ``input_resolution``.
            vit_patch_size: Forwarded as ``patch_size``.
            vit_width: Forwarded as ``width``.
            vit_layers: Forwarded as ``layers``. The default is 24, the depth
                of CLIP ViT-L/14, which the other defaults describe; it was
                previously 768, duplicating ``vit_output_dim``.
            vit_heads: Forwarded as ``heads``.
            vit_output_dim: Forwarded as ``output_dim``.
        """
        super().__init__()
        self.backbone = VisionTransformer(
            input_resolution=vit_input_resolution,
            patch_size=vit_patch_size,
            width=vit_width,
            layers=vit_layers,
            heads=vit_heads,
            output_dim=vit_output_dim,
        )
        # Copy the projection before it is dropped below; deepcopy keeps the
        # style and content heads from sharing storage.
        self.last_layer_style = deepcopy(self.backbone.proj)
        self.last_layer_content = deepcopy(self.backbone.proj)
        # NOTE(review): presumably the backbone then returns un-projected
        # features (CLIP's VisionTransformer skips the projection when
        # ``proj`` is None) -- confirm against the installed clip version.
        self.backbone.proj = None

    def forward(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return backbone features plus normalized style/content embeddings.

        Args:
            pixel_values: Input passed unchanged to the backbone
                (assumed to be a batch of images -- TODO confirm layout).

        Returns:
            ``(features, style_output, content_output)`` in that order.
            ``features`` is the raw backbone output; the other two are
            ``features`` multiplied by the respective head and L2-normalized
            along dimension 1.
        """
        features = self.backbone(pixel_values)

        style_output = features @ self.last_layer_style
        style_output = F.normalize(style_output, dim=1, p=2)

        content_output = features @ self.last_layer_content
        content_output = F.normalize(content_output, dim=1, p=2)

        return features, style_output, content_output