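# Hugging Face Spaces app. The "Spaces: Sleeping" status banner was copied from
# the web page together with the code; it is kept here as a comment so the file parses.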
import os
import ssl

import certifi
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image, ImageEnhance
from torchvision import models
# Make Python's default HTTPS context (used by urllib / http.client) trust
# certifi's CA bundle. This is kept for any HTTPS downloads, e.g. torch.hub /
# torchvision weight fetches.
ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())

# Run on the GPU when available. The model and every input tensor must live on
# this same device (transform_image() moves inputs here).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of output classes; must equal len(class_labels) below.
num_classes = 6

# Build the ResNet-152 architecture without pretrained ImageNet weights. The
# strict load_state_dict() below overwrites every parameter, so downloading
# ImageNet weights first would only add startup time and network traffic.
model = models.resnet152(weights=None)
for param in model.parameters():
    param.requires_grad = False  # Freeze feature extractor

# Replace the classifier head with a small MLP (one 512-unit hidden layer)
# producing num_classes logits.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes),
)

# Load trained weights onto the CPU first, then move the model to `device` so it
# matches the input tensors. Without this move, every prediction fails on a CUDA
# host with a device mismatch.
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
model = model.to(device)
model.eval()

# Class labels, in the order of the model's output logits.
class_labels = ['bird', 'cat', 'deer', 'dog', 'frog', 'horse']
# Image transformation function
def transform_image(image):
    """Preprocess a PIL image into a model-ready batch tensor.

    The image is converted to RGB and resized to 32x32. It is then scaled to
    [0, 1] by ToTensor and normalized per channel with mean 0.5 / std 0.5,
    which maps values to [-1, 1]. This presumably mirrors the training-time
    preprocessing; confirm against the training script.

    Args:
        image: PIL image of any mode (RGB, RGBA, L, P, ...).

    Returns:
        Float tensor of shape (1, 3, 32, 32) on the module-level ``device``.
    """
    # Normalize() uses 3-channel statistics and the network's stem expects 3
    # channels, so RGBA / grayscale / palette inputs must be converted first.
    image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    img_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension
    return img_tensor
# Apply feature filters
def apply_filters(image, brightness, contrast, hue):
    """Adjust the brightness, contrast, and hue of a PIL image.

    Args:
        image: Input PIL image; it is converted to RGB first.
        brightness: ``ImageEnhance.Brightness`` factor (1.0 leaves it unchanged).
        contrast: ``ImageEnhance.Contrast`` factor (1.0 leaves it unchanged).
        hue: Hue shift as a fraction of a full turn of the colour wheel (the UI
            slider spans -0.5..0.5). A value of 0 leaves hue untouched.

    Returns:
        A new RGB PIL image.
    """
    image = image.convert("RGB")  # Ensure RGB mode for the enhancers and OpenCV
    image = ImageEnhance.Brightness(image).enhance(brightness)
    image = ImageEnhance.Contrast(image).enhance(contrast)

    # The 8-bit RGB -> HSV -> RGB round trip is lossy (quantized H and S). Skip
    # it when no hue shift is requested so the default settings don't
    # perturb pixels and, with them, the classifier's input.
    if hue == 0:
        return image

    rgb = np.array(image)
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV).astype(np.float32)
    # OpenCV stores 8-bit hue in [0, 180), so one full turn is 180 units.
    hsv[..., 0] = (hsv[..., 0] + hue * 180) % 180
    rgb = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
    return Image.fromarray(rgb)
# Superimposition function
def superimpose_images(base_image, overlay_image, alpha):
    """Alpha-blend ``overlay_image`` onto ``base_image``.

    Args:
        base_image: Base PIL image (8-bit, e.g. RGB); its size and mode
            determine the overlay's resize/convert target.
        overlay_image: Optional PIL image to blend in. ``None`` disables
            blending.
        alpha: Overlay weight. 0.0 keeps only the base image; 1.0 keeps only
            the overlay.

    Returns:
        ``base_image`` itself when ``overlay_image`` is ``None``; otherwise a
        new image with pixels ``(1 - alpha) * base + alpha * overlay``, clipped
        to [0, 255].
    """
    if overlay_image is None:
        return base_image  # No overlay, return base image as is
    # Match the base image's mode (e.g. an RGBA or grayscale overlay -> RGB)
    # as well as its size. Otherwise the two arrays have different shapes and
    # the blend below fails to broadcast.
    overlay_image = overlay_image.convert(base_image.mode).resize(base_image.size)
    base_array = np.asarray(base_image, dtype=np.float64)
    overlay_array = np.asarray(overlay_image, dtype=np.float64)
    blended_array = (1 - alpha) * base_array + alpha * overlay_array
    blended_array = np.clip(blended_array, 0, 255).astype(np.uint8)
    return Image.fromarray(blended_array)
# Prediction function
def _plot_probabilities(probabilities):
    """Return an xkcd-style matplotlib bar chart of per-class probabilities."""
    with plt.xkcd():
        fig, ax = plt.subplots(figsize=(5, 3))
        ax.bar(class_labels, probabilities, color='skyblue')
        ax.set_ylabel("Probability")
        ax.set_title("Class Probabilities")
        ax.set_ylim([0, 1])
        for i, v in enumerate(probabilities):
            ax.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=10)
    # predict() runs on every widget change. Unclosed pyplot figures stay
    # registered for the process lifetime, so detach this one from pyplot.
    # The Figure object itself remains renderable by the caller.
    plt.close(fig)
    return fig


def predict(image, brightness, contrast, hue, overlay_image, alpha):
    """Apply filters, superimpose, classify the image, and visualize results.

    Args:
        image: Base PIL image, or ``None`` when no image is uploaded.
        brightness: Forwarded to apply_filters().
        contrast: Forwarded to apply_filters().
        hue: Forwarded to apply_filters().
        overlay_image: Optional PIL image forwarded to superimpose_images().
        alpha: Overlay weight forwarded to superimpose_images().

    Returns:
        ``(final_image, fig)``: the processed PIL image and a matplotlib
        Figure of class probabilities. Returns ``(None, None)`` when ``image``
        is ``None``, giving one value per output component in both cases.
    """
    if image is None:
        return None, None
    processed_image = apply_filters(image, brightness, contrast, hue)
    final_image = superimpose_images(processed_image, overlay_image, alpha)
    image_tensor = transform_image(final_image)
    with torch.no_grad():
        output = model(image_tensor)
    # Softmax over the class dimension; take the single batch element.
    probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
    return final_image, _plot_probabilities(probabilities)
# Gradio UI: controls in the left column, the processed image and probability
# chart in the right column. Any control change re-runs predict().
with gr.Blocks() as interface:
    gr.Markdown("<h2 style='text-align: center;'>Image Classifier with Superimposition & Adjustable Filters</h2>")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Base Image")
            overlay_input = gr.Image(type="pil", label="Upload Overlay Image (Optional)")
            brightness = gr.Slider(0.5, 2.0, value=1.0, label="Brightness")
            contrast = gr.Slider(0.5, 2.0, value=1.0, label="Contrast")
            hue = gr.Slider(-0.5, 0.5, value=0.0, label="Hue")
            alpha = gr.Slider(0.0, 1.0, value=0.5, label="Overlay Weight (Alpha)")
        with gr.Column():
            processed_image = gr.Image(label="Final Processed Image")
            bar_chart = gr.Plot(label="Class Probabilities")

    # Positional order must match predict(image, brightness, contrast, hue,
    # overlay_image, alpha) and its (image, figure) return value.
    inputs = [image_input, brightness, contrast, hue, overlay_input, alpha]
    outputs = [processed_image, bar_chart]

    # Register the same real-time handler on every control, in the original order.
    for component in (image_input, overlay_input, brightness, contrast, hue, alpha):
        component.change(predict, inputs=inputs, outputs=outputs)

interface.launch()