import os
import ssl
import json
import urllib.request

import cv2
import torch
import certifi
import numpy as np
import gradio as gr
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image, ImageEnhance

# Use certifi's CA bundle for HTTPS requests (needed below to download the
# ImageNet class-index file on systems with an incomplete certificate store).
ssl._create_default_https_context = lambda *args, **kwargs: ssl.create_default_context(
    cafile=certifi.where()
)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of classes (used by the optional custom 6-class head below)
num_classes = 6

# --- ResNet pretrained on ImageNet ---
# Download the ImageNet class-index file and convert it to
# {0: "tench", 1: "goldfish", ..., 999: "toilet tissue"}.
url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
with urllib.request.urlopen(url) as f:
    imagenet_class_index = json.load(f)
imagenet_labels = {int(k): v[1] for k, v in imagenet_class_index.items()}

# Load the pre-trained ResNet model
model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
# Modify the classifier for 6 classes with an additional hidden layer
# (kept commented out: the app currently runs the stock 1000-class ImageNet head).
# model.fc = nn.Sequential(
#     nn.Linear(model.fc.in_features, 512),
#     nn.ReLU(),
#     nn.Linear(512, num_classes)
# )
# Load trained weights for the 6-class variant
# model.load_state_dict(torch.load('model_old.pth', map_location=torch.device('cpu')))

model = model.to(device)
model.eval()

# Class labels
# class_labels = ['bird', 'cat', 'deer', 'dog', 'frog', 'horse']  # 6-class variant
class_labels = [imagenet_labels[i] for i in range(len(imagenet_labels))]
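# Quick sanity check on the label order (illustrative, not executed): with the
# standard ImageNet-1k index file downloaded above, class_labels[0] should be
# 'tench' and class_labels[1] 'goldfish', matching the mapping comment above.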
class MultiLayerGradCAM:
    """Grad-CAM that averages heatmaps from one or more target layers."""

    def __init__(self, model, target_layers=None):
        self.model = model
        self.target_layers = target_layers if target_layers else ['layer4']
        self.activations = []
        self.gradients = []
        self.handles = []
        # Register hooks on the requested layers
        for name, module in self.model.named_modules():
            if name in self.target_layers:
                self.handles.append(
                    module.register_forward_hook(self._forward_hook)
                )
                self.handles.append(
                    module.register_full_backward_hook(self._backward_hook)
                )

    def _forward_hook(self, module, input, output):
        self.activations.append(output.detach())

    def _backward_hook(self, module, grad_input, grad_output):
        self.gradients.append(grad_output[0].detach())

    def _find_layer(self, layer_name):
        for name, module in self.model.named_modules():
            if name == layer_name:
                return module
        raise ValueError(f"Layer {layer_name} not found in model")

    def generate(self, input_tensor, target_class=None):
        device = next(self.model.parameters()).device
        # Clear activations/gradients captured by any previous call
        self.activations.clear()
        self.gradients.clear()
        self.model.zero_grad()

        # Forward pass
        output = self.model(input_tensor.to(device))
        pred_class = torch.argmax(output).item() if target_class is None else target_class

        # Backward pass on the selected class score
        one_hot = torch.zeros_like(output)
        one_hot[0][pred_class] = 1
        output.backward(gradient=one_hot)

        # Process activations and gradients; backward hooks fire in reverse
        # layer order, so reverse the gradients to align them with activations.
        heatmaps = []
        for act, grad in zip(self.activations, reversed(self.gradients)):
            # Channel weights: global-average-pooled gradients
            weights = F.adaptive_avg_pool2d(grad, 1)
            # Weighted combination of activation maps
            cam = torch.mul(act, weights).sum(dim=1, keepdim=True)
            cam = F.relu(cam)
            # Upsample to input size
            cam = F.interpolate(cam, size=input_tensor.shape[2:],
                                mode='bilinear', align_corners=False)
            heatmaps.append(cam.squeeze().cpu().numpy())

        # Combine heatmaps from different layers
        combined_heatmap = np.mean(heatmaps, axis=0)

        # Normalize to [0, 1]
        combined_heatmap = np.maximum(combined_heatmap, 0)
        combined_heatmap = (combined_heatmap - combined_heatmap.min()) / \
                           (combined_heatmap.max() - combined_heatmap.min() + 1e-10)
        return combined_heatmap, pred_class

    def __del__(self):
        for handle in self.handles:
            handle.remove()


gradcam = MultiLayerGradCAM(model, target_layers=['layer3', 'layer4'])
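# Minimal sketch of using the Grad-CAM helper on its own (kept commented out so
# the Space does not run it at startup; the random input is purely illustrative):
# dummy = torch.randn(1, 3, 224, 224)
# heatmap, cls_idx = gradcam.generate(dummy)
# print(heatmap.shape, class_labels[cls_idx])   # (224, 224) and some ImageNet label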
# Image transformation function
def transform_image(image):
    """Preprocess the input image for the ImageNet-pretrained ResNet."""
    transform = transforms.Compose([
        transforms.Resize(256),        # Resize shorter side to 256, keeping aspect ratio
        transforms.CenterCrop(224),    # Crop the center 224x224 region
        transforms.ToTensor(),         # Convert to tensor (scales to [0, 1])
        transforms.Normalize(          # Normalize using ImageNet mean & std
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    img_tensor = transform(image).unsqueeze(0).to(device)
    return img_tensor
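# For reference, transform_image() returns a batched tensor of shape
# (1, 3, 224, 224) on `device`, e.g. (illustrative, not executed):
# t = transform_image(Image.new("RGB", (640, 480)))
# print(t.shape)   # torch.Size([1, 3, 224, 224])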
# Apply feature filters
def apply_filters(image, brightness, contrast, hue):
    """Adjust brightness, contrast, and hue of the input image."""
    image = image.convert("RGB")  # Ensure RGB mode

    # Adjust brightness
    enhancer = ImageEnhance.Brightness(image)
    image = enhancer.enhance(brightness)

    # Adjust contrast
    enhancer = ImageEnhance.Contrast(image)
    image = enhancer.enhance(contrast)

    # Adjust hue: convert to HSV, shift the H channel, and convert back.
    # OpenCV's H channel spans 0-179, so `hue` acts as a fraction of a full
    # turn of the hue circle.
    image = np.array(image)
    hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
    hsv_image[..., 0] = (hsv_image[..., 0] + hue * 180) % 180
    image = cv2.cvtColor(hsv_image.astype(np.uint8), cv2.COLOR_HSV2RGB)
    return Image.fromarray(image)
# Superimposition function
def superimpose_images(base_image, overlay_image, alpha):
    """Superimpose overlay_image onto base_image with a given alpha blend."""
    if overlay_image is None:
        return base_image  # No overlay; return base image as is

    # Match mode and size of the base image before blending
    overlay_image = overlay_image.convert("RGB").resize(base_image.size)

    # Convert to numpy arrays
    base_array = np.array(base_image).astype(float)
    overlay_array = np.array(overlay_image).astype(float)

    # Blend images
    blended_array = (1 - alpha) * base_array + alpha * overlay_array
    blended_array = np.clip(blended_array, 0, 255).astype(np.uint8)
    return Image.fromarray(blended_array)
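# Worked example of the blend above: superimpose_images(img_a, img_b, 0.5)
# returns the pixel-wise average of the two images, while alpha=0.0 leaves
# img_a unchanged and alpha=1.0 shows only the (resized) overlay.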
# Prediction function
def predict(image, brightness, contrast, hue, overlay_image, alpha):
    """Apply filters, superimpose, classify the image, and visualize results."""
    if image is None:
        return None, None, None

    # Apply feature filters
    processed_image = apply_filters(image, brightness, contrast, hue)

    # Superimpose overlay image
    final_image = superimpose_images(processed_image, overlay_image, alpha)

    # Convert PIL image to tensor and classify
    image_tensor = transform_image(final_image)
    with torch.no_grad():
        output = model(image_tensor)
        probabilities = F.softmax(output, dim=1).cpu().numpy()[0]

    # Grad-CAM heatmap (runs its own forward/backward pass)
    heatmap, _ = gradcam.generate(image_tensor)

    # Create Grad-CAM overlay
    final_np = np.array(final_image)
    heatmap = cv2.resize(heatmap, (final_np.shape[1], final_np.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap_rgb = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
    superimposed = cv2.addWeighted(heatmap_rgb, 0.5, final_np, 0.5, 0)
    gradcam_image = Image.fromarray(superimposed)

    # Generate bar chart of the top-5 class probabilities
    with plt.xkcd():
        fig, ax = plt.subplots(figsize=(6, 4))
        top5_indices = np.argsort(probabilities)[-5:][::-1]  # Indices of top-5 probabilities
        top5_probs = probabilities[top5_indices]
        top5_labels = [class_labels[i] for i in top5_indices]
        ax.bar(top5_labels, top5_probs, color='skyblue')
        ax.set_ylabel("Probability")
        ax.set_title("Class Probabilities")
        ax.set_ylim([0, 1])
        ax.set_xticks(range(len(top5_labels)))
        ax.set_xticklabels(top5_labels, rotation=45, ha="right", fontsize=8)
        for i, v in enumerate(top5_probs):
            ax.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=10)
        fig.tight_layout(pad=3)

    return final_image, gradcam_image, fig
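# Offline smoke test for predict() without launching the UI (illustrative and
# commented out; 'sample.jpg' is a hypothetical local file):
# img = Image.open("sample.jpg")
# final_img, cam_img, prob_fig = predict(img, brightness=1.0, contrast=1.0,
#                                         hue=0.0, overlay_image=None, alpha=0.5)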
# Gradio Interface
with gr.Blocks() as interface:
    gr.Markdown("<h2 style='text-align: center;'>Image Classifier with Superimposition & Adjustable Filters</h2>")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Base Image")
            overlay_input = gr.Image(type="pil", label="Upload Overlay Image (Optional)")
            brightness = gr.Slider(0.5, 2.0, value=1.0, label="Brightness")
            contrast = gr.Slider(0.5, 2.0, value=1.0, label="Contrast")
            hue = gr.Slider(-0.5, 0.5, value=0.0, label="Hue")
            alpha = gr.Slider(0.0, 1.0, value=0.5, label="Overlay Weight (Alpha)")
        with gr.Column():
            processed_image = gr.Image(label="Final Processed Image")
            gradcam_output = gr.Image(label="GradCAM Overlay")
            bar_chart = gr.Plot(label="Class Probabilities")

    inputs = [image_input, brightness, contrast, hue, overlay_input, alpha]
    outputs = [processed_image, gradcam_output, bar_chart]

    # Event listeners for real-time updates: any input change re-runs predict()
    for component in [image_input, overlay_input, brightness, contrast, hue, alpha]:
        component.change(predict, inputs=inputs, outputs=outputs)

interface.launch()