import torch
from typing import Tuple
from dataclasses import dataclass

from transformers import PretrainedConfig, PreTrainedModel

from .csd import CSD
from .config import CSDConfig


@dataclass
class CSDOutput:
    """Bundle of the three tensors produced by one ``CSDModel`` forward pass."""

    # First element of the tuple returned by the wrapped CSD network.
    image_embeds: torch.Tensor
    # Second element; presumably the style embedding -- confirm against CSD.
    style_embeds: torch.Tensor
    # Third element; presumably the content embedding -- confirm against CSD.
    content_embeds: torch.Tensor


class CSDModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that owns a ``CSD`` network as ``self.model``.

    The wrapped network is configured entirely from the ``vit_*`` fields of
    the supplied ``CSDConfig``.
    """

    config_class = CSDConfig

    def __init__(self, config: CSDConfig) -> None:
        """Build the wrapped ``CSD`` network from *config*.

        Args:
            config: Configuration providing the ``vit_*`` hyper-parameters
                forwarded to ``CSD``.
        """
        super().__init__(config)

        # Each listed config attribute is passed to CSD under the same
        # keyword name, in this order.
        vit_fields = (
            "vit_input_resolution",
            "vit_patch_size",
            "vit_width",
            "vit_layers",
            "vit_heads",
            "vit_output_dim",
        )
        vit_kwargs = {field: getattr(config, field) for field in vit_fields}
        self.model = CSD(**vit_kwargs)

    @torch.inference_mode()
    def forward(self, pixel_values: torch.Tensor, **kwargs) -> CSDOutput:
        """Run the wrapped network and package its three outputs.

        The call executes under ``torch.inference_mode()``, so the returned
        tensors are inference tensors and carry no autograd history.

        Args:
            pixel_values: Input tensor handed unchanged to the wrapped
                ``CSD`` network.
            **kwargs: Accepted but ignored.

        Returns:
            A ``CSDOutput`` holding the image, style and content tensors in
            the order the wrapped network returned them.
        """
        image, style, content = self.model(pixel_values)
        return CSDOutput(
            image_embeds=image,
            style_embeds=style,
            content_embeds=content,
        )