Measuring Style Similarity in Diffusion Models

Cloned from learn2phoenix/CSD.

Their model checkpoint (csd-vit-l.pth) was downloaded from their Google Drive.

The original Git repo is in the CSD folder.

Model architecture

The model CSD ("contrastive style descriptor") is initialized from the image encoder of openai/clip-vit-large-patch14. Let $f$ be the function implemented by this image encoder. $f$ consists of a vision Transformer, which takes an image and converts it into a $1024$-dimensional real-valued vector, followed by a single matrix (the "projection matrix") of dimensions $1024 \times 768$, which maps that vector to a $768$-dimensional CLIP embedding.

Now, remove the projection matrix. This gives us $g: \text{Image} \to \mathbb{R}^{1024}$, whose output is the feature vector. Then add two new projection matrices of dimensions $1024 \times 768$, both initialized as copies of the original projection matrix. The output of one is the style vector, and the output of the other is the content vector. All parameters of the resulting model were then fine-tuned for content-style disentanglement using tadeephuy/GradientReversal, resulting in the final model.
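Written as equations (read off the forward method of CSD_CLIP in the loading code, with the notation $\mathbf{h}$, $W_s$, $W_c$ chosen here for illustration), an image $x$ is mapped as follows, where $\mathbf{h}$ is a row vector and $W_s, W_c \in \mathbb{R}^{1024 \times 768}$ are the style and content projection matrices:

$$
\mathbf{h} = g(x) \in \mathbb{R}^{1024}, \qquad
\mathbf{s} = \frac{\mathbf{h} W_s}{\lVert \mathbf{h} W_s \rVert_2} \in \mathbb{R}^{768}, \qquad
\mathbf{c} = \frac{\mathbf{h} W_c}{\lVert \mathbf{h} W_c \rVert_2} \in \mathbb{R}^{768}
$$

Since $\mathbf{s}$ and $\mathbf{c}$ are unit vectors, the cosine similarity between the styles of two images with style vectors $\mathbf{s}$ and $\mathbf{s}'$ is simply the dot product $\mathbf{s} \cdot \mathbf{s}'$.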

The original paper states that the authors trained two models, the second one based on ViT-B, but they did not release that one.

The model takes real-valued image tensors as input. To preprocess images, use the CLIP preprocessor, that is, use _, preprocess = clip.load("ViT-L/14"). Explicitly, the preprocessor performs the following operation (for ViT-L/14, n_px is 224):

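from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode

BICUBIC = InterpolationMode.BICUBIC

def _convert_image_to_rgb(image):
    # Drop alpha / palette modes so every image has exactly 3 channels
    return image.convert("RGB")

def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),  # shorter side -> n_px
        CenterCrop(n_px),                     # square n_px x n_px crop
        _convert_image_to_rgb,
        ToTensor(),                           # uint8 HWC -> float CHW in [0, 1]
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])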

See the documentation for CLIPImageProcessor for details.
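To make the last two steps concrete: ToTensor scales 8-bit pixel values to $[0, 1]$ and reorders the array to channels-first, and Normalize then standardizes each channel with the fixed CLIP mean and standard deviation. The sketch below reproduces only those two steps in NumPy (resize and crop are omitted); clip_normalize is a hypothetical name used for illustration, not part of CLIP or transformers.

```python
import numpy as np

# Per-channel statistics copied from the Normalize(...) call above
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

def clip_normalize(rgb_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map an (H, W, 3) uint8 RGB array to a (3, H, W)
    float32 array, mirroring ToTensor() followed by Normalize(CLIP_MEAN, CLIP_STD)."""
    x = rgb_uint8.astype(np.float32) / 255.0  # ToTensor: scale to [0, 1]
    x = (x - CLIP_MEAN) / CLIP_STD            # Normalize: per-channel standardization
    return x.transpose(2, 0, 1)               # ToTensor: HWC -> CHW

# A black 224 x 224 image: every pixel becomes -mean / std in each channel.
black = clip_normalize(np.zeros((224, 224, 3), dtype=np.uint8))
```

Under this sketch, a pure black pixel maps to roughly $-1.79$ in the red channel, so the model sees zero-centered values rather than raw pixel intensities.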

Also, despite the names "style vector" and "content vector", I noticed by visual inspection that both work about equally well as style embeddings. I don't know why, but I guess that's life? (No, this is not actually supposed to happen. I don't know why the model did not really disentangle style and content. Maybe that's a question for a small research paper.)

You can see this for yourself by changing the line style_output = output["style_output"].squeeze(0) to style_output = output["content_output"].squeeze(0) in the example application below. The resulting t-SNE plot still clusters by style, equally well to my eyes.
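If you want a number rather than an eyeball judgment, one rough check (my own suggestion, not a method from the CSD paper) is to compare the mean cosine similarity of image pairs that share a style label against pairs that do not, once with the style vectors and once with the content vectors. The sketch assumes an (N, d) array of L2-normalized embeddings and a length-N sequence of style labels; style_separation is a hypothetical helper, and where the labels come from is up to you.

```python
import numpy as np

def style_separation(embeddings: np.ndarray, labels) -> float:
    """Hypothetical helper: mean cosine similarity of same-label pairs minus
    that of different-label pairs. Assumes the rows of `embeddings` are unit vectors
    and that at least one label occurs more than once."""
    labels = np.asarray(labels)
    sims = embeddings @ embeddings.T                 # cosine similarity, since rows are unit-norm
    same = labels[:, None] == labels[None, :]        # True where a pair shares a label
    off_diag = ~np.eye(len(labels), dtype=bool)      # exclude each image's similarity to itself
    within = sims[same & off_diag].mean()
    between = sims[~same].mean()                     # ~same never includes the diagonal
    return float(within - between)

# Hypothetical usage, with style_vecs / content_vecs stacked from the pipeline outputs:
# print(style_separation(style_vecs, labels), style_separation(content_vecs, labels))
```

A larger value means tighter clustering by label. If the two numbers come out close, that supports the visual impression that the content head also encodes style.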

How to use it

Quickstart

Go to the examples folder and run the example.ipynb notebook, then run tsne_visualization.py. It will print something like Running on http://127.0.0.1:49860. Open that link and enjoy the pretty interactive picture.

Loading the model

import copy
import torch
import torch.nn as nn
import clip
from transformers import CLIPProcessor
from huggingface_hub import PyTorchModelHubMixin
from transformers import PretrainedConfig

class CSDCLIPConfig(PretrainedConfig):
    model_type = "csd_clip"

    def __init__(
        self,
        name="csd_large",
        embedding_dim=1024,
        feature_dim=1024,
        content_dim=768,
        style_dim=768,
        content_proj_head="default",
        **kwargs
    ):
        super().__init__(**kwargs)
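        self.name = name
        self.embedding_dim = embedding_dim
        self.feature_dim = feature_dim
        self.content_dim = content_dim
        self.style_dim = style_dim
        self.content_proj_head = content_proj_head
        # Explicitly None: transformers.Pipeline consults config.task_specific_params
        self.task_specific_params = None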

class CSD_CLIP(nn.Module, PyTorchModelHubMixin):
    """backbone + projection head"""
    def __init__(self, name='vit_large',content_proj_head='default'):
        super(CSD_CLIP, self).__init__()
        self.content_proj_head = content_proj_head
        if name == 'vit_large':
            clipmodel, _ = clip.load("ViT-L/14")
            self.backbone = clipmodel.visual
            self.embedding_dim = 1024
            self.feature_dim = 1024
            self.content_dim = 768
            self.style_dim = 768
            self.name = "csd_large"
        elif name == 'vit_base':
            clipmodel, _ = clip.load("ViT-B/16")
            self.backbone = clipmodel.visual
            self.embedding_dim = 768 
            self.feature_dim = 512
            self.content_dim = 512
            self.style_dim = 512
            self.name = "csd_base"
        else:
            raise Exception('This model is not implemented')

        self.last_layer_style = copy.deepcopy(self.backbone.proj)
        self.last_layer_content = copy.deepcopy(self.backbone.proj)

        self.backbone.proj = None
        
        self.config = CSDCLIPConfig(
            name=self.name,
            embedding_dim=self.embedding_dim,
            feature_dim=self.feature_dim,
            content_dim=self.content_dim,
            style_dim=self.style_dim,
            content_proj_head=self.content_proj_head
        )

    def get_config(self):
        return self.config.to_dict()

    @property
    def dtype(self):
        return self.backbone.conv1.weight.dtype
    
    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, input_data):
        
        feature = self.backbone(input_data)

        style_output = feature @ self.last_layer_style
        style_output = nn.functional.normalize(style_output, dim=1, p=2)

        content_output = feature @ self.last_layer_content
        content_output = nn.functional.normalize(content_output, dim=1, p=2)
        
        return feature, content_output, style_output

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CSD_CLIP.from_pretrained("yuxi-liu-wired/CSD")
model.to(device);

Loading the pipeline

import torch
from transformers import Pipeline
from typing import Union, List
from PIL import Image

class CSDCLIPPipeline(Pipeline):
    def __init__(self, model, processor, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        super().__init__(model=model, tokenizer=None, device=device)
        self.processor = processor

    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

    def preprocess(self, images):
        if isinstance(images, (str, Image.Image)):
            images = [images]
        
        processed = self.processor(images=images, return_tensors="pt", padding=True, truncation=True)
        return {k: v.to(self.device) for k, v in processed.items()}

    def _forward(self, model_inputs):
        pixel_values = model_inputs['pixel_values'].to(self.model.dtype)
        with torch.no_grad():
            features, content_output, style_output = self.model(pixel_values)
        return {"features": features, "content_output": content_output, "style_output": style_output}

    def postprocess(self, model_outputs):
        return {
            "features": model_outputs["features"].cpu().numpy(),
            "content_output": model_outputs["content_output"].cpu().numpy(),
            "style_output": model_outputs["style_output"].cpu().numpy()
        }

    def __call__(self, images: Union[str, List[str], Image.Image, List[Image.Image]]):
        return super().__call__(images)

processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
pipeline = CSDCLIPPipeline(model=model, processor=processor, device=device)
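Because style_output is L2-normalized inside forward, comparing the styles of two images reduces to a dot product of their style vectors. A minimal sketch, assuming the pipeline object defined above; style_similarity and the file names in the usage comment are hypothetical.

```python
import numpy as np

def style_similarity(style_a, style_b) -> float:
    """Hypothetical helper: cosine similarity between two style vectors,
    each of shape (768,) or (1, 768) as returned by the pipeline."""
    a = np.asarray(style_a, dtype=np.float64).ravel()
    b = np.asarray(style_b, dtype=np.float64).ravel()
    # The model already normalizes its outputs; dividing by the norms again
    # guards against fp16 rounding and lets the helper accept unnormalized input.
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Hypothetical usage with the pipeline above:
# from PIL import Image
# out_a = pipeline(Image.open("painting_a.png"))
# out_b = pipeline(Image.open("painting_b.png"))
# print(style_similarity(out_a["style_output"], out_b["style_output"]))
```

Values near 1 indicate very similar styles. Swapping in "content_output" gives the corresponding similarity for the content head.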

An example application

First, load the model and the pipeline as described above. Then run the following to load the yuxi-liu-wired/style-content-grid-SDXL dataset, compute the style vector of each image, and write the results to a parquet file.

import io
from PIL import Image
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm

def to_jpeg(image):
    buffered = io.BytesIO()
    if image.mode != "RGB":  # JPEG cannot store alpha or palette modes
        image = image.convert("RGB")
    image.save(buffered, format='JPEG')
    return buffered.getvalue() 

def scale_image(image, max_resolution):
    # Downscale so that the longer side is at most max_resolution, preserving aspect ratio
    scale = max_resolution / max(image.width, image.height)
    if scale < 1:
        image = image.resize((round(image.width * scale), round(image.height * scale)))
    return image

def process_dataset(pipeline, dataset_name, dataset_size=900, max_resolution=192):
    dataset = load_dataset(dataset_name, split='train')
    dataset = dataset.select(range(min(dataset_size, len(dataset))))
    
    # Print the column names
    print("Dataset columns:", dataset.column_names)
    
    # Initialize lists to store results
    embeddings = []
    jpeg_images = []
    
    # Process each item in the dataset
    for item in tqdm(dataset, desc="Processing images"):
        try:
            img = item['image']
            
            # If img is a string (file path), load the image
            if isinstance(img, str):
                img = Image.open(img)


            output = pipeline(img)
            style_output = output["style_output"].squeeze(0)
            
            img = scale_image(img, max_resolution)
            jpeg_img = to_jpeg(img)
            
            # Append results to lists
            embeddings.append(style_output)
            jpeg_images.append(jpeg_img)
        except Exception as e:
            print(f"Error processing item: {e}")
    
    # Create a DataFrame with the results
    df = pd.DataFrame({
        'embedding': embeddings,
        'image': jpeg_images
    })
    
    df.to_parquet('processed_dataset.parquet')
    print("Processing complete. Results saved to 'processed_dataset.parquet'")

process_dataset(pipeline, "yuxi-liu-wired/style-content-grid-SDXL", 
                dataset_size=900, max_resolution=192)

After that, you can go to the examples folder and run tsne_visualization.py to browse the images in an interactive Dash app.
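For reference, the dimensionality-reduction step behind such a plot can be sketched with scikit-learn's TSNE applied to the parquet file written above. This is an assumption about how the projection might be done, not a copy of tsne_visualization.py, which may use different settings; tsne_from_dataframe is a hypothetical helper.

```python
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE

def tsne_from_dataframe(df: pd.DataFrame, perplexity: float = 30.0, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: project the 'embedding' column of df to 2-D with t-SNE.
    Returns an (N, 2) array of plot coordinates."""
    # Each cell holds one style vector; stack them into an (N, d) matrix.
    X = np.stack([np.asarray(v, dtype=np.float32) for v in df["embedding"]])
    # scikit-learn requires perplexity to be smaller than the number of samples.
    perplexity = min(perplexity, len(X) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, init="pca", random_state=seed)
    return tsne.fit_transform(X)

# Hypothetical usage, assuming processed_dataset.parquet from the step above:
# df = pd.read_parquet("processed_dataset.parquet")
# xy = tsne_from_dataframe(df)  # xy[:, 0] and xy[:, 1] are the scatter-plot coordinates
```

Fixing random_state makes the layout reproducible between runs. The JPEG bytes in the image column can then be shown as hover thumbnails next to each point.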

Model size: 305M params (Safetensors; tensor types F32 and FP16).
·
Inference API
Unable to determine this model's library. Check the docs .

Model tree for yuxi-liu-wired/CSD

Quantized
(2)
this model