File size: 4,736 Bytes
cc4db04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436f0f6
1093e85
cc4db04
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import numpy as np
import tensorflow as tf
import logging
from PIL import Image
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input as resnet_preprocess
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg_preprocess
import scipy.fftpack
import time
import clip
import torch

# Set up logging
# Root logger at INFO so the timing messages emitted via logging.info by the
# helpers below are printed (the root logger's default level is WARNING).
logging.basicConfig(level=logging.INFO)

# Load models
# All three models are built once at import time and shared by every request.
# ResNet50 / VGG16: ImageNet weights, classifier head dropped
# (include_top=False) and global average pooling applied (pooling='avg'), so
# predict() yields one pooled feature vector per image instead of class
# scores. NOTE(review): Keras presumably downloads the weights on first run —
# needs network access or a warm cache; confirm for the deployment target.
resnet_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
vgg_model = VGG16(weights='imagenet', include_top=False, pooling='avg')
# clip.load returns the model together with its matching image transform,
# both used by extract_clip_features. The model is pinned to the CPU here.
clip_model, preprocess_clip = clip.load("ViT-B/32", device="cpu")

# Preprocess function
def preprocess_img(img_path, target_size=(224, 224), preprocess_func=resnet_preprocess):
    """Load an image file and turn it into a model-ready batch of one.

    Args:
        img_path: path of the image on disk.
        target_size: (height, width) the image is resized to while loading.
        preprocess_func: model-specific normalisation applied to the batch
            (defaults to the ResNet50 ``preprocess_input``).

    Returns:
        The normalised array with a leading batch axis of size 1.
    """
    t0 = time.time()
    pil_img = keras_image.load_img(img_path, target_size=target_size)
    # Keras models expect a batch dimension, so add one in front.
    batch = np.expand_dims(keras_image.img_to_array(pil_img), axis=0)
    batch = preprocess_func(batch)
    logging.info(f"Image preprocessed in {time.time() - t0:.4f} seconds")
    return batch

# Feature extraction function
def extract_features(img_path, model, preprocess_func):
    """Embed one image with a Keras model and return a flat feature vector.

    Args:
        img_path: path of the image on disk.
        model: Keras model whose ``predict`` output is the embedding.
        preprocess_func: normalisation matching ``model``.

    Returns:
        The model output for the single image, flattened to 1-D.
    """
    batch = preprocess_img(img_path, preprocess_func=preprocess_func)
    t0 = time.time()
    embedding = model.predict(batch)
    elapsed = time.time() - t0
    logging.info(f"Features extracted in {elapsed:.4f} seconds")
    return embedding.flatten()

# Calculate cosine similarity
def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between two 1-D vectors.

    The result lies in [-1, 1]. When either vector has zero norm the cosine
    is undefined; instead of the NaN (plus RuntimeWarning) that a bare
    0/0 division would produce, 0.0 is returned — the same convention
    scikit-learn uses.

    Args:
        vec1, vec2: array-likes of equal length.

    Returns:
        The cosine similarity (a NumPy scalar), or 0.0 for a zero vector.
    """
    norm1 = np.linalg.norm(vec1)
    norm2 = np.linalg.norm(vec2)
    if norm1 == 0 or norm2 == 0:
        return 0.0
    return np.dot(vec1, vec2) / (norm1 * norm2)

# pHash related functions
def phashstr(image, hash_size=8, highfreq_factor=4):
    """Compute a DCT-based perceptual hash of a PIL image as a hex string.

    The image is converted to grayscale and resized (Lanczos) to a square of
    side ``hash_size * highfreq_factor``. A 2-D DCT is then applied and the
    top-left ``hash_size x hash_size`` low-frequency block is kept. Each
    coefficient becomes one bit: set when it exceeds the block's median.

    Args:
        image: PIL image to hash.
        hash_size: side of the low-frequency block (hash has hash_size**2 bits).
        highfreq_factor: oversampling factor for the resize before the DCT.

    Returns:
        The bits packed by ``_binary_array_to_hex``.
    """
    side = hash_size * highfreq_factor
    gray = image.convert('L').resize((side, side), Image.Resampling.LANCZOS)
    # Separable 2-D DCT: first along rows' axis 0, then axis 1.
    coeffs = scipy.fftpack.dct(np.asarray(gray), axis=0)
    coeffs = scipy.fftpack.dct(coeffs, axis=1)
    low = coeffs[:hash_size, :hash_size]
    bits = low > np.median(low)
    return _binary_array_to_hex(bits.flatten())

def _binary_array_to_hex(arr):
    """Pack a sequence of truthy/falsy bits into a lowercase hex string.

    Bits are grouped into bytes of 8, least-significant bit first (element
    ``i`` contributes ``2**(i % 8)`` to its byte), and each byte is written
    as two hex digits. A trailing group of fewer than 8 bits is emitted as
    its own byte instead of being silently dropped, so hashes whose bit
    count is not a multiple of 8 (e.g. ``hash_size=5`` -> 25 bits) keep
    every bit. Output for bit counts that are multiples of 8 is unchanged.

    Args:
        arr: iterable of bit values, e.g. a flattened NumPy boolean array.

    Returns:
        Two lowercase hex characters per (possibly partial) byte.
    """
    digits = []
    byte = 0
    last_index = -1
    for last_index, bit in enumerate(arr):
        if bit:
            byte |= 1 << (last_index % 8)
        if last_index % 8 == 7:
            digits.append(f"{byte:02x}")
            byte = 0
    # Flush a final partial byte (nothing to flush for empty input or when
    # the last bit completed a byte).
    if last_index >= 0 and last_index % 8 != 7:
        digits.append(f"{byte:02x}")
    return "".join(digits)

def hamming_distance(hash1, hash2):
    """Count the differing *bits* between two equal-length hex hash strings.

    The pHash strings in this module are hexadecimal, and ``compare_images``
    normalises the distance by ``len(hash) * 4`` (bits per hex digit).
    Comparing characters would register at most one difference per hex
    digit, understating the distance by up to 4x; this pins pHash
    similarity at >= 75%. XOR-ing each digit pair and counting set bits
    gives the true bitwise Hamming distance.

    Raises:
        ValueError: if the hashes differ in length or contain a non-hex
            character.
    """
    if len(hash1) != len(hash2):
        raise ValueError("Hashes must be of the same length")
    return sum(
        bin(int(c1, 16) ^ int(c2, 16)).count("1")
        for c1, c2 in zip(hash1, hash2)
    )

def hamming_to_similarity(distance, hash_length):
    """Convert a Hamming distance into a similarity percentage.

    ``hash_length`` is the number of compared positions. The result is 100
    for identical hashes and 0 when every position differs. A zero
    ``hash_length`` is not guarded against.
    """
    mismatch_ratio = distance / hash_length
    return 100 * (1 - mismatch_ratio)

# CLIP related functions
def extract_clip_features(image_path, model, preprocess, device="cpu"):
    """Embed one image with a CLIP model and return a 1-D NumPy vector.

    Args:
        image_path: path of the image on disk.
        model: CLIP model exposing ``encode_image`` (as returned by
            ``clip.load``).
        preprocess: the transform returned alongside the model by
            ``clip.load``; maps a PIL image to a tensor that
            ``unsqueeze(0)`` turns into a batch of one.
        device: torch device the model lives on. Defaults to ``"cpu"``,
            matching how this module loads its CLIP model.

    Returns:
        The image embedding, moved to CPU and flattened.
    """
    # Open with a context manager so the file handle is released right after
    # preprocessing (Image.open reads lazily and otherwise keeps it open).
    with Image.open(image_path) as img:
        batch = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(batch)
    return features.cpu().numpy().flatten()

# Main function
def compare_images(image1, image2, method):
    """Compare two images with the chosen method and return a similarity score.

    Args:
        image1, image2: file paths of the two images (the Gradio inputs use
            ``type="filepath"``); ``None`` when nothing was uploaded.
        method: one of ``'pHash'``, ``'ResNet50'``, ``'VGG16'``, ``'CLIP'``.

    Returns:
        For ``'pHash'``: the ``hamming_to_similarity`` score (0-100 scale).
        For the embedding methods: cosine similarity of the feature vectors
        (range [-1, 1]). Note the two scales differ.

    Raises:
        gr.Error: if an image is missing or ``method`` is not recognised, so
            Gradio shows a readable message instead of a traceback.
    """
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both images before comparing.")

    start_time = time.time()
    if method == 'pHash':
        # Close both files once hashed; Image.open keeps the handle open.
        with Image.open(image1) as img1, Image.open(image2) as img2:
            hash1 = phashstr(img1)
            hash2 = phashstr(img2)
        distance = hamming_distance(hash1, hash2)
        # Each hex digit of the hash encodes 4 bits.
        similarity = hamming_to_similarity(distance, len(hash1) * 4)
    elif method == 'ResNet50':
        features1 = extract_features(image1, resnet_model, resnet_preprocess)
        features2 = extract_features(image2, resnet_model, resnet_preprocess)
        similarity = cosine_similarity(features1, features2)
    elif method == 'VGG16':
        features1 = extract_features(image1, vgg_model, vgg_preprocess)
        features2 = extract_features(image2, vgg_model, vgg_preprocess)
        similarity = cosine_similarity(features1, features2)
    elif method == 'CLIP':
        features1 = extract_clip_features(image1, clip_model, preprocess_clip)
        features2 = extract_clip_features(image2, clip_model, preprocess_clip)
        similarity = cosine_similarity(features1, features2)
    else:
        # Without this branch `similarity` would be unbound here (e.g. when no
        # method is selected and Gradio passes None).
        raise gr.Error(
            f"Unknown comparison method: {method!r}. "
            "Choose pHash, ResNet50, VGG16 or CLIP."
        )

    logging.info(f"Image comparison using {method} completed in {time.time() - start_time:.4f} seconds")
    return similarity

# Gradio interface
demo = gr.Interface(
    fn=compare_images,
    inputs=[
        gr.Image(type="filepath", label="Upload First Image"),
        gr.Image(type="filepath", label="Upload Second Image"),
        # Default selection so compare_images never receives method=None.
        gr.Radio(
            ["pHash", "ResNet50", "VGG16", "CLIP"],
            value="pHash",
            label="Select Comparison Method",
        ),
    ],
    outputs=gr.Textbox(label="Similarity"),
    title="Image Similarity Comparison",
    description="Upload two images and select the comparison method.",
    # Each example row must supply one value per input component (two images
    # plus a method); otherwise clicking/caching an example passes method=None.
    examples=[
        ["Snipaste_2024-05-31_16-18-31.jpg", "Snipaste_2024-05-31_16-18-52.jpg", "pHash"],
        ["example1.png", "example2.png", "pHash"],
    ],
)

demo.launch()