|
import os
import random
import tempfile
import time
import uuid
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from pytorch_grad_cam import GradCAM, GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from ultralytics import YOLO

from model.load_model import get_model
|
|
|
|
|
|
|
# Classifier architecture name (passed to get_model) and its checkpoint file.
model_name = "efficientnet_b0"
model_path = "efficientnet-b0-best.pth"

# Weights for the YOLO detector whose boxes are cropped before classification.
YOLO_MODEL_WEIGHTS = "yolo-v11-best.pt"

# Class labels; position i is the label for classifier output index i.
classes = ["Healthy", "Resistant", "Susceptible"]

# Transform applied to the image that the heat-maps are drawn on. It only
# crops (no resize): an image smaller than 224 px on a side is zero-padded,
# a larger one keeps just its central 224x224 region.
# NOTE(review): confirm this matches the crop produced by image_transform,
# otherwise the CAM overlay is not aligned with the pixels the model saw.
resizing_transforms = transforms.Compose(transforms=[transforms.CenterCrop(size=224)])
|
|
|
|
|
|
|
def reproduce(seed=42):
    """Seed every random number generator the script uses and pin cuDNN.

    Args:
        seed: Value fed to Python's ``random``, NumPy, the PyTorch CPU
            generator and the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn.benchmark = False
    cudnn.deterministic = True
|
|
|
|
|
def get_grad_cam_results(image, transformed_image, class_index=0):
    """Compute a Grad-CAM map for one class and blend it onto ``image``.

    Uses the module-level ``model`` and ``target_layers``.

    Args:
        image: Picture the heat-map is drawn on; it is scaled with
            ``np.array(image) / 255.0`` (presumably an 8-bit RGB image whose
            size matches the CAM output — confirm).
        transformed_image: Single model-ready tensor without a batch axis;
            a batch axis is added here.
        class_index: Classifier output index whose score is explained.

    Returns:
        ``(overlay, grayscale_cam)``: the RGB overlay from
        ``show_cam_on_image`` and the 2-D CAM for the first (only) batch item.
    """
    batch = transformed_image.unsqueeze(0)
    wanted = [ClassifierOutputTarget(class_index)]

    # The context manager removes the hooks GradCAM registers on the model.
    with GradCAM(model=model, target_layers=target_layers) as explainer:
        heatmap = explainer(input_tensor=batch, targets=wanted)[0, :]

    base = np.array(image) / 255.0
    overlay = show_cam_on_image(base, heatmap, use_rgb=True)
    return overlay, heatmap
|
|
|
|
|
def get_backpropagation_results(transformed_image, class_index=0):
    """Run guided backpropagation for one class with the module-level ``gbp_model``.

    Args:
        transformed_image: Single model-ready tensor without a batch axis;
            a batch axis is added here.
        class_index: Classifier output index passed as ``target_category``.

    Returns:
        ``(raw, display)``: the raw guided-backprop output (needed later for
        Guided Grad-CAM) and its ``deprocess_image`` version for plotting.
    """
    raw_gradients = gbp_model(
        transformed_image.unsqueeze(0), target_category=class_index
    )
    return raw_gradients, deprocess_image(raw_gradients)
|
|
|
|
|
def get_guided_gradcam(image, cam_grayscale, bp):
    """Build a Guided Grad-CAM image from a CAM mask and guided-backprop output.

    Args:
        image: Not used for the computation; accepted so existing callers
            keep working.
        cam_grayscale: 2-D Grad-CAM map as returned by ``get_grad_cam_results``
            (presumably values in [0, 1] — confirm).
        bp: Raw guided-backprop output whose leading axes match
            ``cam_grayscale`` plus a trailing channel axis (presumably HxWx3 —
            confirm).

    Returns:
        The ``deprocess_image`` rendering of the CAM-weighted gradients.
    """
    # Weight every gradient channel by the CAM value at the same pixel.
    # Broadcasting a trailing axis replaces the expand_dims + repeat(3) copy
    # and works for any channel count.
    guided = cam_grayscale[..., np.newaxis] * bp
    # Return the normalised product directly. It must not be passed to
    # show_cam_on_image as the mask: that argument expects a float map in
    # [0, 1], and this uint8 image makes ``255 * mask`` wrap around. The old
    # call also used use_rgb=False, so matplotlib showed BGR colours as RGB.
    return deprocess_image(guided)
|
|
|
|
|
def explain_results(image, class_index=0):
    """Produce the three explanation images for one class of one input image.

    Args:
        image: Input picture; ``image_transform`` turns it into the model
            tensor and ``resizing_transforms`` into the image drawn on.
        class_index: Classifier output index to explain.

    Returns:
        ``(gradcam_overlay, backprop_image, guided_gradcam_image)``.
    """
    model_input = image_transform(image)
    display_image = resizing_transforms(image)

    cam_overlay, cam_mask = get_grad_cam_results(
        display_image, model_input, class_index
    )
    raw_bp, bp_image = get_backpropagation_results(model_input, class_index)
    guided = get_guided_gradcam(display_image, cam_mask, raw_bp)

    return cam_overlay, bp_image, guided
|
|
|
|
|
def make_prediction_and_explain(image):
    """Classify ``image`` and compute explanations for every class.

    Args:
        image: Input picture accepted by ``image_transform``.

    Returns:
        Dict keyed by class name (from ``classes``). Each value has the keys
        ``original_image``, ``prediction`` (e.g. ``"Healthy (97.35%)"``),
        ``gradcam``, ``backpropagation`` and ``guided_gradcam``.
    """
    batch = image_transform(image).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch), dim=1)[0]

    # Scale to percent *before* rounding: ``round(p, 4) * 100`` re-introduces
    # float noise into the label (e.g. 0.07 -> "7.000000000000001%").
    percentages = [round(p * 100, 2) for p in probabilities.tolist()]

    results = {}
    for class_index, class_name in enumerate(classes):
        gradcam, bp_deprocessed, guided_gradcam = explain_results(
            image, class_index=class_index
        )
        results[class_name] = {
            "original_image": image,
            "prediction": f"{class_name} ({percentages[class_index]}%)",
            "gradcam": gradcam,
            "backpropagation": bp_deprocessed,
            "guided_gradcam": guided_gradcam,
        }

    return results
|
|
|
|
|
def save_explanation_results(res, path):
    """Render the output of ``make_prediction_and_explain`` into one image file.

    Each entry of ``res`` becomes one row with four panels: the original
    image (titled with the prediction), the Grad-CAM overlay, guided
    backpropagation and Guided Grad-CAM.

    Args:
        res: Mapping of class name -> dict with the keys ``original_image``,
            ``prediction``, ``gradcam``, ``backpropagation`` and
            ``guided_gradcam``.
        path: Destination passed to ``Figure.savefig``.
    """
    panels = (
        ("gradcam", "GradCAM"),
        ("backpropagation", "Backpropagation"),
        ("guided_gradcam", "Guided GradCAM"),
    )
    # Size the grid from the data instead of a hard-coded 3 rows, so a
    # different number of classes does not raise IndexError. Keep at least
    # one row so the subplot shape stays valid for empty input.
    n_rows = max(len(res), 1)
    # squeeze=False keeps ``axes`` 2-D even for a single row.
    fig, axes = plt.subplots(
        n_rows, 1 + len(panels), figsize=(15, 5 * n_rows), squeeze=False
    )
    try:
        for row, entry in zip(axes, res.values()):
            row[0].imshow(entry["original_image"])
            row[0].set_title(f"Original Image (class: {entry['prediction']})")
            for ax, (key, title) in zip(row[1:], panels):
                ax.imshow(entry[key])
                ax.set_title(title)
        for ax in axes.flat:
            ax.axis("off")

        fig.tight_layout()
        fig.savefig(path, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails, so a
        # long-running process does not accumulate open figures.
        plt.close(fig)
|
|
|
|
|
# Load the classifier once at import time. map_location="cpu" lets a
# checkpoint saved on a GPU load on a CPU-only host.
model, image_transform = get_model(model_name)
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
# The model is only used for inference and explanations, so keep it in eval
# mode. In train mode, BatchNorm would use (and update) per-batch statistics
# of a single image, and dropout would be active.
model.eval()
# Layer whose activations Grad-CAM uses (presumably the last convolution of
# the EfficientNet backbone — confirm against get_model).
target_layers = [model.conv_head]
gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")

yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
|
|
|
|
|
def _explain_and_save(image):
    """Explain ``image``, save the figure as a uniquely named PNG in the temp dir, return its path."""
    save_path = os.path.join(
        tempfile.gettempdir(), f"with-bg-result-{uuid.uuid4().hex}.png"
    )
    save_explanation_results(make_prediction_and_explain(image), save_path)
    return save_path


def get_results(img_path=None, img_for_testing=None, od=False):
    """Classify an image and save explanation figures, optionally per YOLO crop.

    Args:
        img_path: Path of the image to load. If both inputs are given, this
            one is used.
        img_for_testing: Image array accepted by ``PIL.Image.fromarray``
            (presumably HxWx3 uint8 RGB, e.g. from a UI widget — confirm).
        od: If true, run the YOLO detector first and explain the crop of the
            first box of each result. A result with no box falls back to the
            whole image. If false, explain the whole image.

    Returns:
        List of paths of the saved PNG figures (in the system temp directory).

    Raises:
        ValueError: If neither ``img_path`` nor ``img_for_testing`` is given.
    """
    if img_path is None and img_for_testing is None:
        raise ValueError("Either img_path or img_for_testing should be provided.")

    # Use exactly one source, so the detector and the crop always see the
    # same picture.
    if img_path is not None:
        image = Image.open(img_path)
    else:
        image = Image.fromarray(img_for_testing)
    # The classifier and the overlays need 3-channel RGB; palette,
    # grayscale or RGBA inputs would otherwise break the pipeline.
    image = image.convert("RGB")

    if not od:
        return [_explain_and_save(image)]

    result_paths = []
    # Detect on the same PIL image that is cropped, so the box coordinates
    # are in its pixel frame.
    for result in yolo_model(image):
        boxes = result.boxes.xyxy
        if len(boxes) == 0:
            # Nothing detected: explain the whole image instead of raising
            # IndexError.
            crop = image
        else:
            # Only the first box is used (presumably the highest-confidence
            # one — confirm with the detector's output ordering).
            x1, y1, x2, y2 = boxes[0].cpu().numpy().astype(int)
            crop = image.crop((x1, y1, x2, y2))
        result_paths.append(_explain_and_save(crop))

    return result_paths
|
|
|
|
|
if __name__ == "__main__":
    # Batch mode: explain every file in samples/ and write one figure per
    # detection result into sample-results/. The classifier, the Grad-CAM
    # helpers and the YOLO detector were already loaded at module level, so
    # they are reused here rather than loaded a second time.
    reproduce()

    output_dir = "sample-results"
    os.makedirs(output_dir, exist_ok=True)

    for IMAGE_PATH in glob("samples/*"):
        start = time.perf_counter()

        image = Image.open(IMAGE_PATH).convert("RGB")
        # Detect on the same PIL image that is cropped, so the box
        # coordinates refer to exactly these pixels.
        results = yolo_model(image)

        for i, result in enumerate(results):
            # Build the output name with os.path: on Windows, glob yields
            # "samples\\x.jpg", so replacing "samples/" would be a no-op and
            # the input image would be overwritten.
            save_path = os.path.join(
                output_dir,
                f"with-white-bg-result-{i:02d}-{os.path.basename(IMAGE_PATH)}",
            )
            boxes = result.boxes.xyxy
            if len(boxes) == 0:
                print(f"No detection in {IMAGE_PATH}; explaining the whole image")
                bbox_image = image
            else:
                x1, y1, x2, y2 = boxes[0].cpu().numpy().astype(int)
                bbox_image = image.crop((x1, y1, x2, y2))

            res = make_prediction_and_explain(bbox_image)
            save_explanation_results(res, save_path)

        end = time.perf_counter() - start
        print(f"Completed in {end}s")
|
|