import gradio as gr
import numpy as np
import cv2
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing import image
from skimage.metrics import structural_similarity as ssim
import os
import tempfile
from PIL import Image
class ImageCharacterClassifier:
    """Image comparison helpers built on ResNet50 features and grayscale SSIM.

    The ResNet50 backbone is created once per process and shared by every
    instance, so constructing a classifier per request does not reload the
    ImageNet weights each time.
    """

    # Shared feature extractor, created lazily by the first instance.
    _shared_model = None

    def __init__(self, similarity_threshold=0.5):
        """Create a classifier.

        Args:
            similarity_threshold: Value stored on the instance as
                ``similarity_threshold``; the methods of this class do not
                read it themselves.
        """
        if ImageCharacterClassifier._shared_model is None:
            # ImageNet weights without the classification head; global average
            # pooling yields one feature vector per image.
            ImageCharacterClassifier._shared_model = ResNet50(
                weights='imagenet', include_top=False, pooling='avg')
        self.model = ImageCharacterClassifier._shared_model
        self.similarity_threshold = similarity_threshold

    def load_and_preprocess_image(self, image_path, target_size=(224, 224)):
        """Load an image file as a ResNet50-ready batch containing one image.

        The image is resized to ``target_size``, converted to an array, given
        a leading batch axis and passed through ResNet50's
        ``preprocess_input``.
        """
        img = image.load_img(image_path, target_size=target_size)
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        return preprocess_input(img_array)

    def extract_features(self, image_path):
        """Return the model's pooled feature output for the image at ``image_path``."""
        preprocessed_img = self.load_and_preprocess_image(image_path)
        return self.model.predict(preprocessed_img)

    def calculate_ssim(self, img1_path, img2_path):
        """Return the grayscale SSIM between two image files.

        The second image is resized to the first image's width and height
        before comparison. Returns 0.0 if OpenCV cannot decode either file.
        """
        img1 = cv2.imread(img1_path)
        img2 = cv2.imread(img2_path)
        if img1 is None or img2 is None:
            return 0.0
        # SSIM is computed on single-channel images.
        if len(img1.shape) == 3:
            img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
        if len(img2.shape) == 3:
            img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
        # SSIM needs equal shapes; cv2.resize takes (width, height).
        img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
        return ssim(img1, img2)
# Maximum number of comparison images per request, matching the UI text.
_MAX_COMPARISON_IMAGES = 10

# Heading shown above the per-image results.
_RESULTS_HEADER_HTML = (
    "<h3>Results</h3>"
    "<p>Reference image compared with uploaded images</p>"
)


def _upload_path(upload):
    """Return the filesystem path of one uploaded file.

    Depending on the Gradio version, uploads arrive either as path strings or
    as file wrappers that expose the path via ``.name``; accept both.
    """
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


def _save_upload_as_image(src_path, dst_path):
    """Decode ``src_path`` and write it to ``dst_path`` as a BGR image via OpenCV.

    PIL is tried first (the image is converted to RGB); OpenCV's own decoder
    is the fallback.

    Raises:
        ValueError: if neither PIL nor OpenCV can read ``src_path``.
    """
    try:
        with Image.open(src_path) as img:
            rgb = np.array(img.convert('RGB'))
        cv2.imwrite(dst_path, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    except Exception as e1:
        print(f"PIL failed: {str(e1)}")
        fallback = cv2.imread(src_path)
        if fallback is None:
            raise ValueError(f"Could not read image: {src_path}")
        cv2.imwrite(dst_path, fallback)


def _cosine_similarity(a, b):
    """Cosine similarity of two arrays treated as flat vectors.

    Returns 0.0 when either array has zero norm instead of dividing by zero
    (which would yield NaN, and NaN never fails a ``<`` threshold check).
    """
    a = np.ravel(a)
    b = np.ravel(b)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)


def _judge_similarity(ssim_score, max_feature_diff, avg_feature_diff,
                      feature_similarity, similarity_threshold):
    """Apply the similarity rules in order and return ``(is_similar, reason)``.

    The first failing rule decides the reason; if none fails the images are
    considered similar.
    """
    # NOTE(review): these absolute-difference limits are applied to raw,
    # unnormalized ResNet50 activations; confirm they were calibrated for
    # that scale.
    if max_feature_diff > 0.85 or avg_feature_diff > 0.5:
        return False, "Major physical differences detected (missing or misplaced features)"
    if ssim_score < 0.4:
        return False, "Overall structure is too different"
    # The user-selected threshold drives the feature-similarity check.
    if feature_similarity < similarity_threshold:
        return False, "Features don't match well enough"
    return True, "Images are similar"


def _result_html(filename, is_similar, reason):
    """Render one comparison result as an HTML snippet (text is escaped)."""
    import html

    status_color = "#27ae60" if is_similar else "#c0392b"  # green / red
    status_text = "SIMILAR" if is_similar else "NOT SIMILAR"
    status_icon = "✓" if is_similar else "✗"
    return (
        f'<div style="border-left: 4px solid {status_color}; '
        f'padding: 6px 10px; margin: 6px 0;">'
        f'<span style="color: {status_color}; font-weight: bold;">{status_icon}</span> '
        f"<strong>{html.escape(filename)}</strong><br>"
        f"<span>{html.escape(reason)}</span><br>"
        f'<span style="color: {status_color}; font-weight: bold;">{status_text}</span>'
        "</div>"
    )


def _error_html(title, detail):
    """Render an error entry as an HTML snippet (text is escaped)."""
    import html

    return (
        '<div style="color: #c0392b; margin: 6px 0;">'
        f"<strong>{html.escape(title)}</strong><br>"
        f"<span>{html.escape(detail)}</span>"
        "</div>"
    )


def process_images(reference_image, comparison_images, similarity_threshold):
    """Compare every uploaded image against the reference image.

    Args:
        reference_image: Reference image as an RGB numpy array (as delivered
            by the ``gr.Image(type="numpy", image_mode="RGB")`` input), or None.
        comparison_images: Sequence of uploaded files (path strings or objects
            exposing ``.name``), or None/empty.
        similarity_threshold: Minimum cosine similarity between the ResNet50
            feature vectors for an image to count as similar.

    Returns:
        ``(html, gallery)``: an HTML (or plain-text, for input problems)
        summary, and a list of RGB numpy arrays of the successfully processed
        comparison images. Errors are reported in the HTML, never raised.
    """
    try:
        if reference_image is None:
            return "Please upload a reference image.", []
        if not comparison_images:
            return "Please upload comparison images.", []
        if len(comparison_images) > _MAX_COMPARISON_IMAGES:
            return (
                f"Please upload at most {_MAX_COMPARISON_IMAGES} comparison images.",
                [],
            )

        with tempfile.TemporaryDirectory() as temp_dir:
            classifier = ImageCharacterClassifier(
                similarity_threshold=similarity_threshold)

            # PNG is lossless, so no JPEG artifacts are introduced into the
            # images before they are compared.
            ref_path = os.path.join(temp_dir, "reference.png")
            cv2.imwrite(ref_path, cv2.cvtColor(reference_image, cv2.COLOR_RGB2BGR))
            # Reference features are computed once and reused for every image.
            ref_features = classifier.extract_features(ref_path)

            results = []
            html_parts = [_RESULTS_HEADER_HTML]
            for i, comp_image in enumerate(comparison_images):
                src_path = _upload_path(comp_image)
                display_name = os.path.basename(src_path)
                try:
                    comp_path = os.path.join(temp_dir, f"comparison_{i}.png")
                    _save_upload_as_image(src_path, comp_path)

                    ssim_score = classifier.calculate_ssim(ref_path, comp_path)
                    comp_features = classifier.extract_features(comp_path)
                    feature_diff = np.abs(ref_features - comp_features)
                    avg_feature_diff = np.mean(feature_diff)
                    max_feature_diff = np.max(feature_diff)
                    feature_similarity = _cosine_similarity(ref_features, comp_features)

                    is_similar, reason = _judge_similarity(
                        ssim_score, max_feature_diff, avg_feature_diff,
                        feature_similarity, similarity_threshold)

                    print(f"\nDebug for {display_name}:")
                    print(f"SSIM Score: {ssim_score:.3f}")
                    print(f"Max Feature Difference: {max_feature_diff:.3f}")
                    print(f"Average Feature Difference: {avg_feature_diff:.3f}")
                    print(f"Feature Similarity: {feature_similarity:.3f}")

                    html_parts.append(_result_html(display_name, is_similar, reason))

                    # Read the processed image back for the gallery (RGB order).
                    display_img = cv2.imread(comp_path)
                    if display_img is not None:
                        results.append(cv2.cvtColor(display_img, cv2.COLOR_BGR2RGB))
                except Exception as e:
                    # One bad upload must not abort the whole batch.
                    print(f"Error processing {src_path}: {str(e)}")
                    html_parts.append(
                        _error_html(f"Error processing: {display_name}", str(e)))
            return "".join(html_parts), results
    except Exception as e:
        print(f"Main error: {str(e)}")
        return _error_html("An error occurred while comparing images.", str(e)), []
# Gradio user interface
def create_interface():
    """Build the Gradio Blocks UI for the image similarity classifier.

    The "Compare Images" button calls ``process_images`` with the reference
    image, the uploaded comparison files and the slider value, and shows its
    HTML summary and image gallery.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as interface:
        gr.Markdown("# Image Similarity Classifier")
        gr.Markdown("Upload a reference image and up to 10 comparison images to check similarity.")
        with gr.Row():
            with gr.Column():
                reference_input = gr.Image(
                    label="Reference Image",
                    type="numpy",
                    image_mode="RGB",
                )
                # gr.File has no `maximum` option (Gradio versions whose
                # components take no **kwargs raise TypeError on it), so this
                # widget itself does not cap the number of uploads.
                comparison_input = gr.File(
                    label="Comparison Images (Upload up to 10)",
                    file_count="multiple",
                    file_types=["image"],
                )
                threshold_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Similarity Threshold",
                )
                submit_button = gr.Button("Compare Images", variant="primary")
            with gr.Column():
                output_html = gr.HTML(label="Results")
                output_gallery = gr.Gallery(
                    label="Processed Images",
                    columns=5,
                    show_label=True,
                    height="auto",
                )
        # Event listeners must be registered inside the Blocks context.
        submit_button.click(
            fn=process_images,
            inputs=[reference_input, comparison_input, threshold_slider],
            outputs=[output_html, output_gallery],
        )
    return interface
# Entry point: build and serve the UI only when run as a script, not on import.
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    create_interface().launch(share=True)