# clip_aes_onnx / app.py
import gradio as gr
import torch
import clip
import hashlib
import numpy as np
import cv2
from PIL import Image
import onnxruntime as ort
import requests
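
# Convert a boolean bit array into a compact hexadecimal string (used for the perceptual hash).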
def _binary_array_to_hex(arr):
    bit_string = ''.join(str(b) for b in 1 * arr.flatten())
    width = int(np.ceil(len(bit_string) / 4))
    return '{:0>{width}x}'.format(int(bit_string, 2), width=width)
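
# DCT-based perceptual hash (pHash) of a PIL image, returned as a hex string.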
def phashstr(image, hash_size=8, highfreq_factor=4):
    if hash_size < 2:
        raise ValueError('Hash size must be greater than or equal to 2')
    import scipy.fftpack
    img_size = hash_size * highfreq_factor
    image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
    pixels = np.asarray(image)
    dct = scipy.fftpack.dct(scipy.fftpack.dct(pixels, axis=0), axis=1)
    dctlowfreq = dct[:hash_size, :hash_size]
    med = np.median(dctlowfreq)
    diff = dctlowfreq > med
    return _binary_array_to_hex(diff.flatten())
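
# Normalize an array along the given axis (L2 norm by default), guarding against division by zero.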
def normalized(a, axis=-1, order=2):
    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
    l2[l2 == 0] = 1
    return a / np.expand_dims(l2, axis)
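
# Recursively convert NumPy scalars to native Python types so the result is JSON-serializable.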
def convert_numpy_types(data):
    if isinstance(data, dict):
        return {key: convert_numpy_types(value) for key, value in data.items()}
    elif isinstance(data, list):
        return [convert_numpy_types(item) for item in data]
    elif isinstance(data, np.float64):
        return float(data)
    elif isinstance(data, np.int64):
        return int(data)
    else:
        return data
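
# Fetch the ONNX aesthetic-score model and write it to a local file.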
def download_onnx_model(url, filename):
    response = requests.get(url)
    response.raise_for_status()  # fail loudly instead of writing an HTML error page to disk
    with open(filename, 'wb') as f:
        f.write(response.content)
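
# Score an uploaded image: CLIP ViT-L/14 embedding -> ONNX MLP aesthetic score,
# plus perceptual/cryptographic hashes and a Laplacian-variance sharpness metric.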
def predict(image):
    onnx_url = "https://huggingface.co/haor/aesthetics/resolve/main/aesthetic_score_mlp.onnx"
    onnx_path = "aesthetic_score_mlp.onnx"
    # Note: the ONNX scorer is downloaded and the CLIP backbone is loaded on every call.
    download_onnx_model(onnx_url, onnx_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ort_session = ort.InferenceSession(onnx_path)
    model2, preprocess = clip.load("ViT-L/14", device=device)
    image = Image.fromarray(image)
    # Sharpness metric: variance of the Laplacian on the grayscale image.
    image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
    laplacian_variance = cv2.Laplacian(image_np, cv2.CV_64F).var()
    # Perceptual and cryptographic hashes of the input image.
    phash = phashstr(image)
    md5 = hashlib.md5(image.tobytes()).hexdigest()
    sha1 = hashlib.sha1(image.tobytes()).hexdigest()
    # CLIP image embedding, L2-normalized, then scored by the ONNX MLP.
    inputs = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        img_emb = model2.encode_image(inputs)
        img_emb = normalized(img_emb.cpu().numpy())
    ort_inputs = {ort_session.get_inputs()[0].name: img_emb.astype(np.float32)}
    ort_outs = ort_session.run(None, ort_inputs)
    prediction = ort_outs[0].item()
    result = {
        "clip_aesthetic": prediction,
        "phash": phash,
        "md5": md5,
        "sha1": sha1,
        "laplacian_variance": laplacian_variance
    }
    return convert_numpy_types(result)
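
# Gradio UI: image in, JSON metrics out.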
title = "CLIP Aesthetic Score"
description = "Upload an image to predict its aesthetic score using the CLIP model and calculate other image metrics."

gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.JSON(label="Result"),
    title=title,
    description=description,
    examples=[["example1.jpg"], ["example2.jpg"]]
).launch()