"""Gradio demo that visualises CAM-style explanations for timm image models."""

from io import BytesIO

import gradio as gr
import numpy as np
import requests
import timm
import torch
from PIL import Image
from pytorch_grad_cam import (
    AblationCAM,
    EigenCAM,
    FullGrad,
    GradCAM,
    GradCAMPlusPlus,
    HiResCAM,
    ScoreCAM,
    XGradCAM,
)
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from timm.data import create_transform

# Only offer models that actually ship pretrained weights: load_model() always
# requests pretrained=True, which fails for architectures without weights.
MODELS = timm.list_models(pretrained=True)

# Mapping from the UI label to the pytorch-grad-cam implementation.
CAM_METHODS = {
    "GradCAM": GradCAM,
    "HiResCAM": HiResCAM,
    "ScoreCAM": ScoreCAM,
    "GradCAM++": GradCAMPlusPlus,
    "AblationCAM": AblationCAM,
    "XGradCAM": XGradCAM,
    "EigenCAM": EigenCAM,
    "FullGrad": FullGrad,
}

# Seconds to wait for a remote image before giving up.
_REQUEST_TIMEOUT = 30


def load_model(model_name):
    """Create the pretrained timm model *model_name* and put it in eval mode."""
    model = timm.create_model(model_name, pretrained=True)
    model.eval()
    return model


def process_image(image_path, model):
    """Load an image and turn it into a normalised batch of size one.

    Args:
        image_path: Local file path or an ``http(s)://`` URL.
        model: timm model whose ``pretrained_cfg`` defines the preprocessing.

    Returns:
        The transformed image tensor with a leading batch dimension added.

    Raises:
        requests.HTTPError: If a remote image cannot be downloaded.
    """
    if image_path.startswith(("http://", "https://")):
        response = requests.get(image_path, timeout=_REQUEST_TIMEOUT)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path

    # Force three channels: RGBA / palette / grayscale inputs would otherwise
    # break the per-channel normalisation in the transform.
    with Image.open(source) as raw:
        image = raw.convert("RGB")

    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False,
    )
    return transform(image).unsqueeze(0)


def get_cam_image(model, image, target_layer, cam_method):
    """Run the selected CAM method and overlay the heatmap on the input image.

    Args:
        model: The model to explain.
        image: Normalised input batch as produced by :func:`process_image`.
        target_layer: Module whose activations the CAM method inspects.
        cam_method: Key into ``CAM_METHODS``.

    Returns:
        A PIL image with the CAM heatmap blended over the de-normalised input.
    """
    cam = CAM_METHODS[cam_method](model=model, target_layers=[target_layer])
    # No explicit targets are passed, so the library picks the target itself.
    grayscale_cam = cam(input_tensor=image)

    # Undo the normalisation so the overlay is drawn on a viewable image.
    config = model.pretrained_cfg
    mean = torch.tensor(config['mean'], device=image.device).view(-1, 1, 1)
    std = torch.tensor(config['std'], device=image.device).view(-1, 1, 1)
    rgb_img = (image.detach().squeeze(0) * std + mean).permute(1, 2, 0).cpu().numpy()
    rgb_img = np.clip(rgb_img, 0, 1)

    cam_image = show_cam_on_image(rgb_img, grayscale_cam[0, :], use_rgb=True)
    return Image.fromarray(cam_image)


def get_feature_info(model):
    """Return the module names listed in ``model.feature_info`` (or ``[]``)."""
    if hasattr(model, 'feature_info'):
        return [f['module'] for f in model.feature_info]
    return []


def get_target_layer(model, target_layer_name):
    """Resolve *target_layer_name* to a submodule of *model*.

    Returns ``None`` when no name is given or the name does not exist. An
    empty name is treated as "not given" because ``get_submodule("")`` would
    return the whole model, which is not a usable CAM target.
    """
    if not target_layer_name:
        return None
    try:
        return model.get_submodule(target_layer_name)
    except AttributeError:
        print(f"WARNING: Layer '{target_layer_name}' not found in the model.")
        return None


def _find_fallback_layer(model):
    """Pick a default CAM target: last feature module, else last Conv2d."""
    feature_info = get_feature_info(model)
    if feature_info:
        layer = get_target_layer(model, feature_info[-1])
        if layer is not None:
            print(f"Using last feature module: {feature_info[-1]}")
            return layer

    # Reached when there is no feature_info, or when its last entry does not
    # resolve to a real submodule.
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            print(f"Fallback: Using last convolutional layer: {name}")
            return module
    return None


def explain_image(model_name, image_path, cam_method, feature_module):
    """Gradio callback: produce a CAM visualisation for the uploaded image.

    Args:
        model_name: timm model identifier chosen in the UI.
        image_path: Path (or URL) of the image to explain.
        cam_method: Key into ``CAM_METHODS``.
        feature_module: Optional dotted name of the layer to inspect; when it
            is missing or invalid a fallback layer is chosen automatically.

    Returns:
        A PIL image with the CAM overlay.

    Raises:
        gr.Error: If a required UI input has not been provided.
        ValueError: If no suitable target layer can be found.
    """
    if not model_name:
        raise gr.Error("Please select a model.")
    if not image_path:
        raise gr.Error("Please upload an image.")
    if cam_method not in CAM_METHODS:
        raise gr.Error("Please select a CAM method.")

    model = load_model(model_name)
    image = process_image(image_path, model)

    target_layer = get_target_layer(model, feature_module)
    if target_layer is None:
        target_layer = _find_fallback_layer(model)
    if target_layer is None:
        raise ValueError("Could not find a suitable target layer.")

    return get_cam_image(model, image, target_layer, cam_method)


def update_feature_modules(model_name):
    """Gradio callback: refresh the feature-module dropdown for *model_name*."""
    if not model_name:
        return gr.Dropdown(choices=[], value=None)

    # Only the architecture is needed to read feature_info, so skip the
    # pretrained-weight download here.
    model = timm.create_model(model_name, pretrained=False)
    feature_modules = get_feature_info(model)
    return gr.Dropdown(
        choices=feature_modules,
        value=feature_modules[-1] if feature_modules else None,
    )


with gr.Blocks() as demo:
    gr.Markdown("# Explainable AI with timm models")
    gr.Markdown("Upload an image, select a model, CAM method, and optionally a specific feature module to visualize the explanation.")

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(choices=MODELS, label="Select Model")
            image_input = gr.Image(type="filepath", label="Upload Image")
            cam_method_dropdown = gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method")
            feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")
            explain_button = gr.Button("Explain Image")

        with gr.Column():
            output_image = gr.Image(type="pil", label="Explained Image")

    model_dropdown.change(
        fn=update_feature_modules,
        inputs=[model_dropdown],
        outputs=[feature_module_dropdown],
    )

    explain_button.click(
        fn=explain_image,
        inputs=[model_dropdown, image_input, cam_method_dropdown, feature_module_dropdown],
        outputs=[output_image],
    )

demo.launch()