meepmoo's picture
Upload folder using huggingface_hub
0dcccdd verified
raw
history blame
4.01 kB
# Borrowed from https://github.com/discus0434/aesthetic-predictor-v2-5/blob/3125a9e/src/aesthetic_predictor_v2_5/siglip_v2_5.py.
import os
from collections import OrderedDict
from os import PathLike
from typing import Final
import torch
import torch.nn as nn
from transformers import (
SiglipImageProcessor,
SiglipVisionConfig,
SiglipVisionModel,
logging,
)
from transformers.image_processing_utils import BatchFeature
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
# Restrict the `transformers` logger to error-level messages for this module's
# consumers (this is a process-wide setting applied at import time).
logging.set_verbosity_error()
# Download location of the pretrained scoring-head weights. It is used by
# `convert_v2_5_from_siglip` when no existing local checkpoint path is given.
URL: Final[str] = (
"https://github.com/discus0434/aesthetic-predictor-v2-5/raw/main/models/aesthetic_predictor_v2_5.pth"
)
class AestheticPredictorV2_5Head(nn.Module):
    """Regression head that maps an image embedding to a single score.

    The head is a stack of ``nn.Linear`` layers with widths
    ``config.hidden_size -> 1024 -> 128 -> 64 -> 16 -> 1``. Every linear
    layer except the last is followed by ``nn.Dropout``; no non-linear
    activation is inserted between layers.
    """

    def __init__(self, config: SiglipVisionConfig) -> None:
        """Build the scoring stack.

        Args:
            config: Vision config; only ``config.hidden_size`` is read, as
                the input width of the first linear layer.
        """
        super().__init__()
        widths = (config.hidden_size, 1024, 128, 64, 16)
        drop_rates = (0.5, 0.5, 0.5, 0.2)

        # Modules are appended in the order Linear, Dropout, Linear, ... so
        # the positional indices inside ``nn.Sequential`` (and therefore the
        # state-dict keys ``scoring_head.<idx>.*``) stay 0, 2, 4, 6, 8 for
        # the linear layers.
        stages: list[nn.Module] = []
        for in_dim, out_dim, rate in zip(widths, widths[1:], drop_rates):
            stages.append(nn.Linear(in_dim, out_dim))
            stages.append(nn.Dropout(rate))
        stages.append(nn.Linear(widths[-1], 1))

        self.scoring_head = nn.Sequential(*stages)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Run ``image_embeds`` through the scoring stack and return the result."""
        score = self.scoring_head(image_embeds)
        return score
class AestheticPredictorV2_5Model(SiglipVisionModel):
    """SigLIP vision encoder with an aesthetic scoring head on top.

    The pooled encoder output is L2-normalised along its last dimension and
    fed to ``AestheticPredictorV2_5Head`` (stored as ``self.layers``) to
    produce one score per input.
    """

    # Not referenced inside this class; presumably the patch size of the
    # SigLIP checkpoint this model is paired with -- confirm against callers.
    PATCH_SIZE = 14

    def __init__(self, config: SiglipVisionConfig, *args, **kwargs) -> None:
        """Build the encoder via the parent class and attach the scoring head.

        Args:
            config: Vision config forwarded to ``SiglipVisionModel`` and to
                the scoring head.
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.
        """
        super().__init__(config, *args, **kwargs)
        self.layers = AestheticPredictorV2_5Head(config)
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
    ) -> tuple | ImageClassifierOutputWithNoAttention:
        """Score a batch of images and optionally compute an MSE loss.

        Args:
            pixel_values: Image tensor passed straight to the SigLIP encoder.
            labels: Optional regression targets. Assumed to hold one value
                per scored image (TODO confirm shape with callers); they are
                flattened and cast to the prediction's dtype/device before
                the loss is computed.
            return_dict: Whether to return an
                ``ImageClassifierOutputWithNoAttention`` instead of a tuple.
                Defaults to ``self.config.use_return_dict`` when ``None``.

        Returns:
            ``(loss, prediction, image_embeds)`` when ``return_dict`` is
            false, otherwise an ``ImageClassifierOutputWithNoAttention`` with
            ``loss``, ``logits`` (the prediction) and ``hidden_states`` (the
            un-normalised pooled embedding). ``loss`` is ``None`` when no
            ``labels`` are given.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Always ask the encoder for a dict-style output: with a tuple
        # output the attribute access ``outputs.pooler_output`` below would
        # fail, which previously broke every ``return_dict=False`` call.
        outputs = super().forward(
            pixel_values=pixel_values,
            return_dict=True,
        )
        image_embeds = outputs.pooler_output
        image_embeds_norm = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        prediction = self.layers(image_embeds_norm)

        loss = None
        if labels is not None:
            loss_fct = nn.MSELoss()
            # The loss was previously invoked with no arguments (TypeError).
            # Flatten both sides so (batch, 1) predictions line up with
            # (batch,) or (batch, 1) targets without silent broadcasting.
            targets = labels.to(device=prediction.device, dtype=prediction.dtype)
            loss = loss_fct(prediction.reshape(-1), targets.reshape(-1))

        if not return_dict:
            return (loss, prediction, image_embeds)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=prediction,
            hidden_states=image_embeds,
        )
class AestheticPredictorV2_5Processor(SiglipImageProcessor):
    """``SiglipImageProcessor`` subclass used alongside the aesthetic model.

    ``__init__`` and ``__call__`` delegate unchanged to the parent class;
    the only added behaviour is a default checkpoint name in
    ``from_pretrained``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to ``SiglipImageProcessor.__init__``."""
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs) -> BatchFeature:
        """Forward all arguments to the parent's ``__call__`` and return its result."""
        return super().__call__(*args, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str
        | PathLike = "google/siglip-so400m-patch14-384",
        *args,
        **kwargs,
    ) -> "AestheticPredictorV2_5Processor":
        """Load a processor, defaulting to the SigLIP so400m/patch14/384 checkpoint.

        Args:
            pretrained_model_name_or_path: Model id or local path handed to
                the parent ``from_pretrained``.
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Returns:
            Whatever the parent ``from_pretrained`` returns for this class.
        """
        # The first parameter of a classmethod is the class, so it is named
        # ``cls`` (it was ``self``); zero-argument ``super()`` binds to it.
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
def convert_v2_5_from_siglip(
    predictor_name_or_path: str | PathLike | None = None,
    encoder_model_name: str = "google/siglip-so400m-patch14-384",
    *args,
    **kwargs,
) -> tuple[AestheticPredictorV2_5Model, AestheticPredictorV2_5Processor]:
    """Build the aesthetic model and its processor with pretrained weights.

    The encoder and processor are loaded from ``encoder_model_name``; the
    scoring-head weights are then loaded into ``model.layers`` and the model
    is switched to eval mode.

    Args:
        predictor_name_or_path: Local path to the scoring-head checkpoint.
            When it is ``None`` *or does not exist on disk*, the weights are
            downloaded from ``URL`` instead (no error is raised for a
            missing path).
        encoder_model_name: Model id or path passed to both
            ``from_pretrained`` calls.
        *args: Extra positional arguments forwarded to both
            ``from_pretrained`` calls.
        **kwargs: Extra keyword arguments forwarded to both
            ``from_pretrained`` calls.

    Returns:
        ``(model, processor)`` with the model in eval mode.

    Raises:
        TypeError: If the loaded checkpoint is not an ``OrderedDict``.
    """
    model = AestheticPredictorV2_5Model.from_pretrained(
        encoder_model_name, *args, **kwargs
    )
    processor = AestheticPredictorV2_5Processor.from_pretrained(
        encoder_model_name, *args, **kwargs
    )

    has_local_checkpoint = predictor_name_or_path is not None and os.path.exists(
        predictor_name_or_path
    )
    if has_local_checkpoint:
        # NOTE(security): torch.load unpickles the file; depending on the
        # installed torch version and its `weights_only` default this can
        # execute arbitrary code. Only point this at trusted checkpoints.
        state_dict = torch.load(predictor_name_or_path, map_location="cpu")
    else:
        state_dict = torch.hub.load_state_dict_from_url(URL, map_location="cpu")

    # Explicit check instead of `assert`, which is removed under `python -O`
    # and would let a malformed checkpoint reach load_state_dict unchecked.
    if not isinstance(state_dict, OrderedDict):
        raise TypeError(
            "Expected the predictor checkpoint to be an OrderedDict state dict, "
            f"got {type(state_dict).__name__}."
        )

    model.layers.load_state_dict(state_dict)
    model.eval()
    return model, processor