haor committed
Commit
a45bc86
1 Parent(s): b1d96f3

Update app.py

Files changed (1)
  1. app.py +40 -32
app.py CHANGED
@@ -9,14 +9,17 @@ from PIL import Image
 import onnxruntime as ort
 import requests
 
-def _binary_array_to_hex(arr):
-    bit_string = ''.join(str(b) for b in 1 * arr.flatten())
+def binary_array_to_hex(arr: np.ndarray) -> str:
+    """Convert a binary array to a hex string."""
+    bit_string = ''.join(str(b) for b in arr.astype(int).flatten())
     width = int(np.ceil(len(bit_string) / 4))
     return '{:0>{width}x}'.format(int(bit_string, 2), width=width)
 
-def phashstr(image, hash_size=8, highfreq_factor=4):
+def phash(image: Image.Image, hash_size: int = 8, highfreq_factor: int = 4) -> str:
+    """Calculate the perceptual hash of an image."""
     if hash_size < 2:
        raise ValueError('Hash size must be greater than or equal to 2')
+
     import scipy.fftpack
     img_size = hash_size * highfreq_factor
     image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
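
Note on this hunk: `phash` now passes the boolean DCT mask directly into `binary_array_to_hex`, so the join needs an int cast; `str(True)` would otherwise yield 'True' instead of '1' and break `int(bit_string, 2)` (the old code used `1 * arr.flatten()` for the same reason). A minimal sanity check of the helper's intended behavior, on toy values not taken from the app:

    import numpy as np

    bits = np.array([[True, False], [True, True]])                    # toy 2x2 mask
    bit_string = ''.join(str(b) for b in bits.astype(int).flatten())  # '1011'
    width = int(np.ceil(len(bit_string) / 4))                         # one hex digit per 4 bits
    assert '{:0>{width}x}'.format(int(bit_string, 2), width=width) == 'b'  # 0b1011 == 0xb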
@@ -25,14 +28,16 @@ def phashstr(image, hash_size=8, highfreq_factor=4):
     dctlowfreq = dct[:hash_size, :hash_size]
     med = np.median(dctlowfreq)
     diff = dctlowfreq > med
-    return _binary_array_to_hex(diff.flatten())
+    return binary_array_to_hex(diff)
 
-def normalized(a, axis=-1, order=2):
+def normalize(a: np.ndarray, axis: int = -1, order: int = 2) -> np.ndarray:
+    """Normalize a numpy array."""
     l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
     l2[l2 == 0] = 1
     return a / np.expand_dims(l2, axis)
 
 def convert_numpy_types(data):
+    """Convert numpy types to Python native types."""
     if isinstance(data, dict):
         return {key: convert_numpy_types(value) for key, value in data.items()}
     elif isinstance(data, list):
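
The renamed `normalize` does row-wise L2 normalization with a divide-by-zero guard. A quick check on a toy vector, illustrative only:

    import numpy as np

    v = np.array([[3.0, 4.0]])
    l2 = np.atleast_1d(np.linalg.norm(v, 2, -1))  # row norms: [5.0]
    l2[l2 == 0] = 1                               # guard: all-zero rows divide by 1
    print(v / np.expand_dims(l2, -1))             # [[0.6 0.8]]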
@@ -44,50 +49,53 @@ def convert_numpy_types(data):
     else:
         return data
 
-def download_onnx_model(url, filename):
+def download_model(url: str, path: str) -> None:
+    """Download a model from a URL and save it to a file."""
     response = requests.get(url)
-    with open(filename, 'wb') as f:
+    with open(path, 'wb') as f:
         f.write(response.content)
 
-def predict(image):
-    onnx_url = "https://huggingface.co/haor/aesthetics/resolve/main/aesthetic_score_mlp.onnx"
-    onnx_path = "aesthetic_score_mlp.onnx"
-    download_onnx_model(onnx_url, onnx_path)
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    ort_session = ort.InferenceSession(onnx_path)
-
-    model2, preprocess = clip.load("ViT-L/14", device=device)
-    image = Image.fromarray(image)
+# Load models outside the function
+onnx_url = "https://huggingface.co/haor/aesthetics/resolve/main/aesthetic_score_mlp.onnx"
+onnx_path = "aesthetic_score_mlp.onnx"
+download_model(onnx_url, onnx_path)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+ort_session = ort.InferenceSession(onnx_path)
+model, preprocess = clip.load("ViT-L/14", device=device)
+
+def predict(image: np.ndarray) -> dict:
+    """Predict the aesthetic score of an image using CLIP."""
+    image = Image.fromarray(image)
     image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
     laplacian_variance = cv2.Laplacian(image_np, cv2.CV_64F).var()
-    phash = phashstr(image)
-    md5 = hashlib.md5(image.tobytes()).hexdigest()
-    sha1 = hashlib.sha1(image.tobytes()).hexdigest()
+    phash_str = phash(image)
+    md5_hash = hashlib.md5(image.tobytes()).hexdigest()
+    sha1_hash = hashlib.sha1(image.tobytes()).hexdigest()
 
     inputs = preprocess(image).unsqueeze(0).to(device)
+
     with torch.no_grad():
-        img_emb = model2.encode_image(inputs)
-        img_emb = normalized(img_emb.cpu().numpy())
-
-    ort_inputs = {ort_session.get_inputs()[0].name: img_emb.astype(np.float32)}
-    ort_outs = ort_session.run(None, ort_inputs)
-    prediction = ort_outs[0].item()
-
+        img_emb = model.encode_image(inputs)
+        img_emb = normalize(img_emb.cpu().numpy())
+        ort_inputs = {ort_session.get_inputs()[0].name: img_emb.astype(np.float32)}
+        ort_outs = ort_session.run(None, ort_inputs)
+        prediction = ort_outs[0].item()
+
     result = {
         "clip_aesthetic": prediction,
-        "phash": phash,
-        "md5": md5,
-        "sha1": sha1,
+        "phash": phash_str,
+        "md5": md5_hash,
+        "sha1": sha1_hash,
         "laplacian_variance": laplacian_variance
     }
     return convert_numpy_types(result)
 
 title = "CLIP Aesthetic Score"
-description = "Upload an image to predict its aesthetic score using the CLIP model and calculate other image metrics."
+description = "Upload an image to predict its aesthetic score using the CLIP model and other metrics."
 
 gr.Interface(
-    fn=predict,
+    fn=predict,
     inputs=gr.Image(type="numpy"),
     outputs=gr.JSON(label="Result"),
     title=title,
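
The main change in this hunk hoists model loading to module scope, so the ONNX download and `clip.load` run once at startup instead of on every `predict` call. A hypothetical smoke test, assuming app.py's module-level setup has already run; the random image is purely illustrative:

    import numpy as np

    dummy = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)  # fake RGB upload
    result = predict(dummy)
    print(result["clip_aesthetic"], result["phash"], result["laplacian_variance"])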
 