import os

import clip
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from PIL import Image


class AestheticPredictor:
    """Aesthetic Score Predictor.

    Args:
        clip_model_dir (str): Path to the directory of the CLIP model.
        sac_model_path (str): Path to the pre-trained SAC model.
        device (str): Device to use for computation ("cuda" or "cpu").
    """

    def __init__(self, clip_model_dir=None, sac_model_path=None, device=None):
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # Download weights from the Hugging Face Hub when no local paths
        # are given; snapshot_download caches, so repeated calls are cheap.
        suffix = "aesthetic"
        if clip_model_dir is None:
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            clip_model_dir = os.path.join(model_path, suffix)
        if sac_model_path is None:
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            sac_model_path = os.path.join(
                model_path, suffix, "sac+logos+ava1-l14-linearMSE.pth"
            )
        self.clip_model, self.preprocess = self._load_clip_model(
            clip_model_dir
        )
        self.sac_model = self._load_sac_model(sac_model_path, input_size=768)

    class MLP(pl.LightningModule):  # noqa
        """Linear head that regresses a scalar score from a CLIP embedding."""

        def __init__(self, input_size):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(input_size, 1024),
                nn.Dropout(0.2),
                nn.Linear(1024, 128),
                nn.Dropout(0.2),
                nn.Linear(128, 64),
                nn.Dropout(0.1),
                nn.Linear(64, 16),
                nn.Linear(16, 1),
            )

        def forward(self, x):
            return self.layers(x)

    @staticmethod
    def normalized(a, axis=-1, order=2):
        """Normalize the array to unit norm."""
        # Marked static: this is called as self.normalized(array), so
        # without @staticmethod the array would bind to `self`.
        l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
        l2[l2 == 0] = 1
        return a / np.expand_dims(l2, axis)

    def _load_clip_model(self, model_dir: str, model_name: str = "ViT-L/14"):
        """Load the CLIP model."""
        model, preprocess = clip.load(
            model_name, download_root=model_dir, device=self.device
        )
        return model, preprocess

    def _load_sac_model(self, model_path, input_size):
        """Load the SAC model."""
        model = self.MLP(input_size)
        # map_location keeps loading robust on CPU-only machines when the
        # checkpoint was saved from a GPU.
        ckpt = torch.load(model_path, map_location=self.device)
        model.load_state_dict(ckpt)
        model.to(self.device)
        model.eval()
        return model

    def predict(self, image_path):
        """Predict the aesthetic score for a given image.

        Args:
            image_path (str): Path to the image file.

        Returns:
            float: Predicted aesthetic score.
        """
        pil_image = Image.open(image_path)
        image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Extract CLIP features
            image_features = self.clip_model.encode_image(image)
            # Normalize features
            normalized_features = self.normalized(
                image_features.cpu().detach().numpy()
            )
            # Predict score
            prediction = self.sac_model(
                torch.from_numpy(normalized_features)
                .float()
                .to(self.device)
            )
        return prediction.item()
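

# Note (not from the original source): the checkpoint name matches the
# LAION "improved aesthetic predictor" linear head; its scores fall roughly
# on the 1-10 AVA rating scale, higher meaning more aesthetically pleasing.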
if __name__ == "__main__":
    # Configuration
    img_path = "/home/users/xinjie.wang/xinjie/asset3d-gen/outputs/imageto3d/demo_objects/bed/sample_0/sample_0_raw.png"  # noqa
    # clip_model_dir = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/clip"  # noqa
    # sac_model_path = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/sac/sac+logos+ava1-l14-linearMSE.pth"  # noqa

    # Initialize the predictor
    predictor = AestheticPredictor()

    # Predict the aesthetic score
    score = predictor.predict(img_path)
    print("Aesthetic score predicted by the model:", score)