haor committed on
Commit 167e9d9
1 Parent(s): 31ab6ff

Create app.py

Files changed (1)
app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
+ import gradio as gr
+ import torch
+ import clip
+ import pandas as pd
+ import hashlib
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import onnxruntime as ort
+ import requests
+
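+ # Convert a boolean bit array to a compact hexadecimal string.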
+ def _binary_array_to_hex(arr):
+     bit_string = ''.join(str(b) for b in 1 * arr.flatten())
+     width = int(np.ceil(len(bit_string) / 4))
+     return '{:0>{width}x}'.format(int(bit_string, 2), width=width)
+
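+ # Perceptual hash (pHash): DCT of a grayscale thumbnail, with bits set where the
+ # low-frequency coefficients exceed their median.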
+ def phashstr(image, hash_size=8, highfreq_factor=4):
+     if hash_size < 2:
+         raise ValueError('Hash size must be greater than or equal to 2')
+     import scipy.fftpack
+     img_size = hash_size * highfreq_factor
+     image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
+     pixels = np.asarray(image)
+     dct = scipy.fftpack.dct(scipy.fftpack.dct(pixels, axis=0), axis=1)
+     dctlowfreq = dct[:hash_size, :hash_size]
+     med = np.median(dctlowfreq)
+     diff = dctlowfreq > med
+     return _binary_array_to_hex(diff.flatten())
+
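+ # L2-normalize an array along the given axis, guarding against division by zero.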
+ def normalized(a, axis=-1, order=2):
+     l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
+     l2[l2 == 0] = 1
+     return a / np.expand_dims(l2, axis)
+
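+ # Recursively convert NumPy scalars inside dicts and lists to native Python types
+ # so the result can be serialized as JSON.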
+ def convert_numpy_types(data):
+     if isinstance(data, dict):
+         return {key: convert_numpy_types(value) for key, value in data.items()}
+     elif isinstance(data, list):
+         return [convert_numpy_types(item) for item in data]
+     elif isinstance(data, np.float64):
+         return float(data)
+     elif isinstance(data, np.int64):
+         return int(data)
+     else:
+         return data
+
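+ # Download a file (here, the ONNX aesthetic-score head) to a local path.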
+ def download_onnx_model(url, filename):
+     response = requests.get(url)
+     with open(filename, 'wb') as f:
+         f.write(response.content)
+
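+ # Gradio handler: score one uploaded image (an RGB NumPy array) and return the metrics as a dict.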
+ def predict(image):
+     onnx_url = "https://huggingface.co/haor/aesthetics/resolve/main/aesthetic_score_mlp.onnx"
+     onnx_path = "aesthetic_score_mlp.onnx"
+     # Note: the scoring head is re-downloaded and CLIP is re-loaded on every call.
+     download_onnx_model(onnx_url, onnx_path)
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     ort_session = ort.InferenceSession(onnx_path)
+
+     model2, preprocess = clip.load("ViT-L/14", device=device)
+     image = Image.fromarray(image)
+     # Sharpness proxy (variance of the Laplacian), perceptual hash, and checksums.
+     image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
+     laplacian_variance = cv2.Laplacian(image_np, cv2.CV_64F).var()
+     phash = phashstr(image)
+     md5 = hashlib.md5(image.tobytes()).hexdigest()
+     sha1 = hashlib.sha1(image.tobytes()).hexdigest()
+
+     # Encode the image with CLIP and L2-normalize the embedding.
+     inputs = preprocess(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         img_emb = model2.encode_image(inputs)
+     img_emb = normalized(img_emb.cpu().numpy())
+
+     # Score the normalized embedding with the ONNX MLP.
+     ort_inputs = {ort_session.get_inputs()[0].name: img_emb.astype(np.float32)}
+     ort_outs = ort_session.run(None, ort_inputs)
+     prediction = ort_outs[0].item()
+
+     result = {
+         "clip_aesthetic": prediction,
+         "phash": phash,
+         "md5": md5,
+         "sha1": sha1,
+         "laplacian_variance": laplacian_variance
+     }
+     return convert_numpy_types(result)
+
+ title = "CLIP Aesthetic Score"
+ description = "Upload an image to predict its aesthetic score using the CLIP model and calculate other image metrics."
+
+ gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="numpy"),
+     outputs=gr.JSON(label="Result"),
+     title=title,
+     description=description,
+     examples=[["example1.jpg"], ["example2.jpg"]]
+ ).launch()
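
For reference, a minimal sketch of exercising the handler outside the Gradio UI. It assumes the functions above are defined in the current session (before .launch() is reached) and that example1.jpg is available locally; gr.Image(type="numpy") passes the upload to predict as an RGB NumPy array, which is what this snippet mimics:

import numpy as np
from PIL import Image

# Load a test image the way the numpy image input would pass it: as an RGB array.
img = np.array(Image.open("example1.jpg").convert("RGB"))

scores = predict(img)
print(scores)  # keys: clip_aesthetic, phash, md5, sha1, laplacian_variance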