# Borrowed from https://github.com/discus0434/aesthetic-predictor-v2-5/blob/3125a9e/src/aesthetic_predictor_v2_5/siglip_v2_5.py.
import os
from collections import OrderedDict
from os import PathLike
from typing import Final

import torch
import torch.nn as nn
from transformers import (
    SiglipImageProcessor,
    SiglipVisionConfig,
    SiglipVisionModel,
    logging,
)
from transformers.image_processing_utils import BatchFeature
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention

logging.set_verbosity_error()

URL: Final[str] = (
    "https://github.com/discus0434/aesthetic-predictor-v2-5/raw/main/models/aesthetic_predictor_v2_5.pth"
)


class AestheticPredictorV2_5Head(nn.Module):
    """MLP scoring head that maps a pooled SigLIP image embedding to a single aesthetic score."""

    def __init__(self, config: SiglipVisionConfig) -> None:
        super().__init__()
        self.scoring_head = nn.Sequential(
            nn.Linear(config.hidden_size, 1024),
            nn.Dropout(0.5),
            nn.Linear(1024, 128),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.Dropout(0.5),
            nn.Linear(64, 16),
            nn.Dropout(0.2),
            nn.Linear(16, 1),
        )

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        return self.scoring_head(image_embeds)


class AestheticPredictorV2_5Model(SiglipVisionModel):
    PATCH_SIZE = 14

    def __init__(self, config: SiglipVisionConfig, *args, **kwargs) -> None:
        super().__init__(config, *args, **kwargs)
        # Scoring head attached on top of the SigLIP vision encoder.
        self.layers = AestheticPredictorV2_5Head(config)
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
    ) -> tuple | ImageClassifierOutputWithNoAttention:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # Always request a dict from the encoder so pooler_output is available;
        # return_dict only controls the format of this model's own output.
        outputs = super().forward(
            pixel_values=pixel_values,
            return_dict=True,
        )
        # L2-normalize the pooled embedding before scoring.
        image_embeds = outputs.pooler_output
        image_embeds_norm = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        prediction = self.layers(image_embeds_norm)

        loss = None
        if labels is not None:
            # Mean-squared-error regression loss against the target scores.
            loss_fct = nn.MSELoss()
            loss = loss_fct(prediction.view(-1), labels.view(-1))

        if not return_dict:
            return (loss, prediction, image_embeds)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=prediction,
            hidden_states=image_embeds,
        )


class AestheticPredictorV2_5Processor(SiglipImageProcessor):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs) -> BatchFeature:
        return super().__call__(*args, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str
        | PathLike = "google/siglip-so400m-patch14-384",
        *args,
        **kwargs,
    ) -> "AestheticPredictorV2_5Processor":
        # from_pretrained is a classmethod on the parent class, so it must be a
        # classmethod here too for AestheticPredictorV2_5Processor.from_pretrained(...)
        # to resolve the checkpoint name rather than binding it as self.
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)


def convert_v2_5_from_siglip(
    predictor_name_or_path: str | PathLike | None = None,
    encoder_model_name: str = "google/siglip-so400m-patch14-384",
    *args,
    **kwargs,
) -> tuple[AestheticPredictorV2_5Model, AestheticPredictorV2_5Processor]:
    """Assemble the v2.5 predictor: a pretrained SigLIP encoder plus the scoring-head weights."""
    model = AestheticPredictorV2_5Model.from_pretrained(
        encoder_model_name, *args, **kwargs
    )
    processor = AestheticPredictorV2_5Processor.from_pretrained(
        encoder_model_name, *args, **kwargs
    )

    # Load the scoring-head weights from a local checkpoint if one exists,
    # otherwise download them from the upstream repository.
    if predictor_name_or_path is None or not os.path.exists(predictor_name_or_path):
        state_dict = torch.hub.load_state_dict_from_url(URL, map_location="cpu")
    else:
        state_dict = torch.load(predictor_name_or_path, map_location="cpu")
    assert isinstance(state_dict, OrderedDict)

    model.layers.load_state_dict(state_dict)
    model.eval()

    return model, processor
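

# A minimal usage sketch, not part of the borrowed module: build the predictor and
# score a single image. The file name "example.jpg" is an illustrative assumption.
if __name__ == "__main__":
    from PIL import Image

    model, processor = convert_v2_5_from_siglip()
    image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    inputs = processor(images=image, return_tensors="pt")
    with torch.inference_mode():
        output = model(pixel_values=inputs.pixel_values)
    print(f"aesthetic score: {output.logits.squeeze().item():.2f}")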