import gradio as gr
import torch
import clip
import hashlib
import numpy as np
import cv2
import scipy.fftpack
from PIL import Image
import onnxruntime as ort
import requests

def binary_array_to_hex(arr: np.ndarray) -> str:
    """Convert a binary array to a zero-padded hex string (e.g. 8x8 bits -> 16 hex chars)."""
    bit_string = ''.join(str(int(b)) for b in arr.flatten())
    width = int(np.ceil(len(bit_string) / 4))
    return '{:0>{width}x}'.format(int(bit_string, 2), width=width)

def phash(image: Image.Image, hash_size: int = 8, highfreq_factor: int = 4) -> str:
    """Calculate the perceptual hash (pHash) of an image."""
    if hash_size < 2:
        raise ValueError('Hash size must be greater than or equal to 2')
    img_size = hash_size * highfreq_factor
    image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
    pixels = np.asarray(image)
    # 2D DCT; the top-left block holds the low-frequency structure of the image.
    dct = scipy.fftpack.dct(scipy.fftpack.dct(pixels, axis=0), axis=1)
    dctlowfreq = dct[:hash_size, :hash_size]
    # Threshold each low-frequency coefficient against the median to get the hash bits.
    med = np.median(dctlowfreq)
    diff = dctlowfreq > med
    return binary_array_to_hex(diff)
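
# Companion sketch (not part of the original app): phash hex strings are
# usually compared by Hamming distance on their underlying bits rather than
# by exact string equality. A minimal helper, assuming equal-length hashes:
def phash_distance(hex_a: str, hex_b: str) -> int:
    """Count differing bits between two equal-length phash hex strings."""
    return bin(int(hex_a, 16) ^ int(hex_b, 16)).count('1')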

def normalize(a: np.ndarray, axis: int = -1, order: int = 2) -> np.ndarray:
    """L2-normalize a numpy array along the given axis."""
    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
    l2[l2 == 0] = 1  # avoid division by zero for all-zero vectors
    return a / np.expand_dims(l2, axis)

def convert_numpy_types(data):
    """Recursively convert numpy scalars to Python native types."""
    if isinstance(data, dict):
        return {key: convert_numpy_types(value) for key, value in data.items()}
    elif isinstance(data, list):
        return [convert_numpy_types(item) for item in data]
    elif isinstance(data, np.floating):  # covers float32 and float64
        return float(data)
    elif isinstance(data, np.integer):  # covers int32 and int64
        return int(data)
    else:
        return data
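
# Illustrative: convert_numpy_types({"a": np.float64(0.5), "b": [np.int64(3)]})
# returns {"a": 0.5, "b": [3]}, which gr.JSON can serialize directly.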

def download_model(url: str, path: str) -> None:
    """Download a model from a URL and save it to a file."""
    response = requests.get(url)
    response.raise_for_status()  # fail loudly instead of saving an error page
    with open(path, 'wb') as f:
        f.write(response.content)

# Download and load models once at import time, outside the request handler.
onnx_url = "https://huggingface.co/haor/aesthetics/resolve/main/aesthetic_score_mlp.onnx"
onnx_path = "aesthetic_score_mlp.onnx"
download_model(onnx_url, onnx_path)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Note: some onnxruntime builds require an explicit providers list, e.g.
# ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]).
ort_session = ort.InferenceSession(onnx_path)
model, preprocess = clip.load("ViT-L/14", device=device)

def predict(image: np.ndarray) -> dict:
    """Predict the aesthetic score of an image with CLIP, plus hashes and a sharpness metric."""
    pil_image = Image.fromarray(image)
    # Variance of the Laplacian is a common blur/sharpness heuristic.
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    laplacian_variance = cv2.Laplacian(gray, cv2.CV_64F).var()
    phash_str = phash(pil_image)
    md5_hash = hashlib.md5(pil_image.tobytes()).hexdigest()
    sha1_hash = hashlib.sha1(pil_image.tobytes()).hexdigest()
    # CLIP image embedding, L2-normalized, then scored by the ONNX MLP head.
    inputs = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        img_emb = model.encode_image(inputs)
    img_emb = normalize(img_emb.cpu().numpy())
    ort_inputs = {ort_session.get_inputs()[0].name: img_emb.astype(np.float32)}
    ort_outs = ort_session.run(None, ort_inputs)
    prediction = ort_outs[0].item()
    result = {
        "clip_aesthetic": prediction,
        "phash": phash_str,
        "md5": md5_hash,
        "sha1": sha1_hash,
        "laplacian_variance": laplacian_variance,
    }
    return convert_numpy_types(result)
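
# Illustrative direct call (the filename "sample.jpg" is hypothetical, not part
# of the app): predict(np.array(Image.open("sample.jpg").convert("RGB")))
# returns a JSON-serializable dict with the keys shown above.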

title = "CLIP Aesthetic Score"
description = "Upload an image to predict its aesthetic score using the CLIP model and other metrics."

gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.JSON(label="Result"),
    title=title,
    description=description,
    examples=[["example1.jpg"], ["example2.jpg"]]
).launch()