clip_aes_onnx / app.py
haor's picture
Update app.py
2e8cc5c verified
raw
history blame
3.69 kB
import gradio as gr
import torch
import clip
import pandas as pd
import hashlib
import numpy as np
import cv2
from PIL import Image
import onnxruntime as ort
import requests
def binary_array_to_hex(arr: np.ndarray) -> str:
    """Convert a binary array to a zero-padded lowercase hex string.

    The array is flattened in row-major order and each element is read as
    one bit, most significant bit first. The result is left-padded with
    zeros to ``ceil(n_bits / 4)`` hex digits so leading zero bits are kept.

    Args:
        arr: Boolean array or array of 0/1 integers. Elements whose ``int()``
            is not 0 or 1 make the bit string invalid and raise ValueError.

    Returns:
        Hex string; ``""`` for an empty array.
    """
    bit_string = "".join(str(int(bit)) for bit in arr.flatten())
    if not bit_string:
        # int("", 2) would raise ValueError; an empty array simply has no bits.
        return ""
    width = -(-len(bit_string) // 4)  # ceil division: one hex digit per 4 bits
    return f"{int(bit_string, 2):0>{width}x}"
def phash(image: Image.Image, hash_size: int = 8, highfreq_factor: int = 4) -> str:
    """Compute the DCT-based perceptual hash (pHash) of *image*.

    The image is converted to 8-bit grayscale ("L") and resized to a square
    of side ``hash_size * highfreq_factor`` using Lanczos resampling. A 2-D
    DCT is then applied (along axis 0, then axis 1). Each coefficient in the
    top-left ``hash_size`` x ``hash_size`` low-frequency corner becomes a 1
    bit when it exceeds the median of that corner.

    Returns:
        Hex string encoding ``hash_size ** 2`` bits (16 hex digits by default).

    Raises:
        ValueError: if ``hash_size`` is less than 2.
    """
    if hash_size < 2:
        raise ValueError('Hash size must be greater than or equal to 2')
    import scipy.fftpack

    side = hash_size * highfreq_factor
    grey = image.convert('L').resize((side, side), Image.Resampling.LANCZOS)
    coeffs = np.asarray(grey)
    # Separable 2-D DCT: transform columns first, then rows.
    for ax in (0, 1):
        coeffs = scipy.fftpack.dct(coeffs, axis=ax)
    low_freq = coeffs[:hash_size, :hash_size]
    return binary_array_to_hex(low_freq > np.median(low_freq))
def normalize(a: np.ndarray, axis: int = -1, order: int = 2) -> np.ndarray:
    """Scale *a* so every slice along *axis* has unit ``order``-norm.

    Norms come from ``np.linalg.norm(a, ord=order, axis=axis)``. Slices whose
    norm is zero are divided by 1 instead, so they come back unchanged rather
    than as NaN. The norms pass through ``np.atleast_1d`` before being
    re-expanded on *axis*, so a 1-D input is returned with shape ``(1, n)``.
    """
    norms = np.linalg.norm(a, ord=order, axis=axis)
    norms = np.atleast_1d(norms)
    # Guard against division by zero for all-zero slices.
    norms[norms == 0] = 1
    divisor = np.expand_dims(norms, axis)
    return a / divisor
def convert_numpy_types(data):
    """Recursively convert NumPy values into native Python types.

    Used to make results JSON-serialisable. Dicts and lists are rebuilt with
    their values converted (dict keys are left as-is). Any NumPy scalar
    (``np.generic``, e.g. ``float64``, ``float32``, ``int64``, ``int32``,
    ``bool_``) is converted with ``.item()``. An ``np.ndarray`` is converted
    with ``.tolist()``. Every other value is returned unchanged.
    """
    if isinstance(data, dict):
        return {key: convert_numpy_types(value) for key, value in data.items()}
    if isinstance(data, list):
        return [convert_numpy_types(item) for item in data]
    if isinstance(data, np.ndarray):
        # tolist() also converts every element to a native Python scalar.
        return data.tolist()
    if isinstance(data, np.generic):
        return data.item()
    return data
def download_model(url: str, path: str, timeout: float = 60.0, chunk_size: int = 1 << 20) -> None:
    """Download *url* and save the response body to *path*.

    The body is streamed to ``path + ".part"`` and moved onto *path* only
    after the transfer completes. An interrupted or failed download
    therefore never leaves a truncated or bogus model file at *path*.

    Args:
        url: HTTP(S) URL of the file to fetch.
        path: Destination file path; overwritten if it already exists.
        timeout: Seconds to wait for the server to connect or send data.
        chunk_size: Number of bytes read per streamed chunk.

    Raises:
        requests.HTTPError: if the server returns a 4xx/5xx status (instead of
            saving the error page as if it were the model).
        requests.RequestException: on connection errors or timeouts.
    """
    import contextlib
    import os

    tmp_path = f"{path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic on the same filesystem: readers see the old file or the new one.
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a partial download behind; re-raise the original error.
        with contextlib.suppress(FileNotFoundError):
            os.remove(tmp_path)
        raise
# Load models once at import time so every request reuses them.
onnx_url = "https://huggingface.co/haor/aesthetics/resolve/main/aesthetic_score_mlp.onnx"
onnx_path = "aesthetic_score_mlp.onnx"
download_model(onnx_url, onnx_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Since onnxruntime 1.9, GPU-enabled builds raise ValueError unless the
# execution providers are passed explicitly. Use CUDA for the aesthetic MLP
# only when torch is also on CUDA and this ORT build offers it; keep the
# CPU provider as a fallback.
ort_providers = ["CPUExecutionProvider"]
if device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers():
    ort_providers.insert(0, "CUDAExecutionProvider")
ort_session = ort.InferenceSession(onnx_path, providers=ort_providers)
model, preprocess = clip.load("ViT-L/14", device=device)
def _image_hashes(image: Image.Image) -> dict:
    """Return the pHash plus MD5/SHA-1 hex digests of *image*."""
    # MD5/SHA-1 hash the decoded pixel buffer (Image.tobytes()), not the
    # uploaded file bytes.
    raw = image.tobytes()
    return {
        "phash": phash(image),
        "md5": hashlib.md5(raw).hexdigest(),
        "sha1": hashlib.sha1(raw).hexdigest(),
    }


def _laplacian_variance(image: Image.Image) -> float:
    """Return the variance of the Laplacian of *image* converted to grayscale.

    Often used as a sharpness score: higher values mean more edge detail.
    """
    gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
    return cv2.Laplacian(gray, cv2.CV_64F).var()


def _clip_aesthetic_score(image: Image.Image) -> float:
    """Score *image* with the ONNX aesthetic MLP applied to its CLIP embedding."""
    inputs = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        img_emb = model.encode_image(inputs)
    # clip.load only upcasts the model to float32 on CPU, so on CUDA the
    # embedding is presumably fp16. Upcast before L2-normalising to avoid
    # fp16 rounding in the norm; the ONNX model is fed float32 anyway.
    img_emb = normalize(img_emb.float().cpu().numpy())
    ort_inputs = {ort_session.get_inputs()[0].name: img_emb.astype(np.float32)}
    ort_outs = ort_session.run(None, ort_inputs)
    return ort_outs[0].item()


def predict(image: np.ndarray) -> dict:
    """Compute the CLIP aesthetic score and auxiliary metrics for an image.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")``.
            Gradio passes ``None`` when the user submits without uploading.

    Returns:
        JSON-serialisable dict with keys ``clip_aesthetic``, ``phash``,
        ``md5``, ``sha1`` and ``laplacian_variance``, in that order.

    Raises:
        gr.Error: if no image was provided (shown as a message in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    pil_image = Image.fromarray(image)
    hashes = _image_hashes(pil_image)
    result = {
        "clip_aesthetic": _clip_aesthetic_score(pil_image),
        "phash": hashes["phash"],
        "md5": hashes["md5"],
        "sha1": hashes["sha1"],
        "laplacian_variance": _laplacian_variance(pil_image),
    }
    return convert_numpy_types(result)
title = "CLIP Aesthetic Score"
description = "Upload an image to predict its aesthetic score using the CLIP model and other metrics."

# One image in, one JSON result out. The example paths are relative, so they
# resolve against the working directory (presumably the Space root).
image_input = gr.Image(type="numpy")
json_output = gr.JSON(label="Result")
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=json_output,
    title=title,
    description=description,
    examples=[["example1.jpg"], ["example2.jpg"]],
)
demo.launch()