Nabeel Raza
add: OD option
a0630af
import random
import time
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from model.load_model import get_model
from torchvision import transforms
from pytorch_grad_cam import GradCAM, GuidedBackpropReLUModel
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image
from ultralytics import YOLO
# from rembg import remove
import uuid
# Static variables
# Weights file for the classifier; read by torch.load when the model is built.
model_path = "efficientnet-b0-best.pth"
# Architecture identifier handed to get_model().
model_name = "efficientnet_b0"
# Weights file used to construct the YOLO detector.
YOLO_MODEL_WEIGHTS = "yolo-v11-best.pt"
# Class labels; a label's position in this list is used as its class index
# when the per-class explanations are generated.
classes = ["Healthy", "Resistant", "Susceptible"]
# Center-crops an image to 224x224; applied to the image that the CAM overlays
# are drawn on (presumably matching the model input size — TODO confirm
# against the transform returned by get_model).
resizing_transforms = transforms.Compose([transforms.CenterCrop(224)])
# Function definitions
def reproduce(seed=42):
    """Seed every random number generator in use and pin cuDNN behaviour.

    Args:
        seed: Value applied to ``random``, NumPy and PyTorch (CPU and all
            CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN autotuning and request deterministic algorithms.
    cudnn.benchmark = False
    cudnn.deterministic = True
def get_grad_cam_results(image, transformed_image, class_index=0):
    """Compute a Grad-CAM map for one class and overlay it on the image.

    Relies on the module-level ``model`` and ``target_layers``.

    Args:
        image: Image the overlay is drawn on; it is converted with
            ``np.array`` and divided by 255 before blending.
        transformed_image: Model input tensor without a batch dimension
            (one is added here).
        class_index: Index of the class output to explain.

    Returns:
        Tuple ``(visualization, grayscale_cam)``: the overlay produced by
        ``show_cam_on_image`` and the CAM of the single input.
    """
    batch = transformed_image.unsqueeze(0)
    class_targets = [ClassifierOutputTarget(class_index)]

    with GradCAM(model=model, target_layers=target_layers) as cam:
        # Only one image is in the batch, so keep the first CAM.
        grayscale_cam = cam(input_tensor=batch, targets=class_targets)[0, :]

    scaled_image = np.array(image) / 255.0
    overlay = show_cam_on_image(scaled_image, grayscale_cam, use_rgb=True)
    return overlay, grayscale_cam
def get_backpropagation_results(transformed_image, class_index=0):
    """Run guided backpropagation for one class via the module-level ``gbp_model``.

    Args:
        transformed_image: Model input tensor without a batch dimension
            (one is added here).
        class_index: Passed to ``gbp_model`` as ``target_category``.

    Returns:
        Tuple ``(backpropagation, bp_deprocessed)``: the raw output of
        ``gbp_model`` and the same output passed through ``deprocess_image``.
    """
    raw_gradients = gbp_model(
        transformed_image.unsqueeze(0), target_category=class_index
    )
    return raw_gradients, deprocess_image(raw_gradients)
def get_guided_gradcam(image, cam_grayscale, bp):
    """Combine a CAM with guided-backpropagation output and overlay the result.

    Args:
        image: Image the overlay is drawn on; converted with ``np.array``
            and divided by 255.
        cam_grayscale: CAM array; it gains a trailing axis that is repeated
            three times before being multiplied with ``bp``.
        bp: Guided-backpropagation output multiplied element-wise with the
            expanded CAM.

    Returns:
        The image returned by ``show_cam_on_image`` (called with
        ``use_rgb=False``).
    """
    # Give the CAM a trailing axis of length 3 so it can multiply ``bp``.
    three_channel_cam = np.repeat(
        np.asanyarray(cam_grayscale)[..., np.newaxis], 3, axis=-1
    )
    guided = deprocess_image(three_channel_cam * bp)
    scaled_image = np.array(image) / 255.0
    return show_cam_on_image(scaled_image, guided, use_rgb=False)
def explain_results(image, class_index=0):
    """Produce the three explanation images for one class.

    Args:
        image: Input image; it is passed through ``image_transform`` for the
            model and through ``resizing_transforms`` for display.
        class_index: Index of the class to explain.

    Returns:
        Tuple ``(gradcam_overlay, deprocessed_backpropagation,
        guided_gradcam)``.
    """
    model_input = image_transform(image)
    display_image = resizing_transforms(image)

    gradcam_overlay, cam_mask = get_grad_cam_results(
        display_image, model_input, class_index
    )
    raw_bp, bp_image = get_backpropagation_results(model_input, class_index)
    guided = get_guided_gradcam(display_image, cam_mask, raw_bp)

    return gradcam_overlay, bp_image, guided
def make_prediction_and_explain(image):
    """Classify an image and build the explanation images for every class.

    Uses the module-level ``model``, ``image_transform`` and ``classes``.
    The model is switched to eval mode as a side effect.

    Args:
        image: Image accepted by ``image_transform``.

    Returns:
        Dict keyed by class name. Each value is a dict with the keys
        ``original_image``, ``prediction`` (``"<class> (<percent>%)"``),
        ``gradcam``, ``backpropagation`` and ``guided_gradcam``.
    """
    batch = image_transform(image).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch), dim=1)

    # Scale to percent *before* rounding. Rounding first and multiplying by
    # 100 afterwards leaks float noise into the label
    # (e.g. 0.07 * 100 == 7.000000000000001).
    percentages = [round(p * 100, 2) for p in probabilities[0].tolist()]

    results = {}
    # The explanation helpers need gradients, so they run outside no_grad().
    for class_index, class_name in enumerate(classes):
        gradcam, bp_deprocessed, guided_gradcam = explain_results(
            image, class_index=class_index
        )
        results[class_name] = {
            "original_image": image,
            "prediction": f"{class_name} ({percentages[class_index]}%)",
            "gradcam": gradcam,
            "backpropagation": bp_deprocessed,
            "guided_gradcam": guided_gradcam,
        }
    return results
def save_explanation_results(res, path):
    """Render the per-class explanations as a 3x4 image grid and save it.

    Each entry of ``res`` fills one row: original image, GradCAM,
    backpropagation and guided GradCAM.

    Args:
        res: Mapping as returned by ``make_prediction_and_explain``; every
            value provides ``original_image``, ``prediction``, ``gradcam``,
            ``backpropagation`` and ``guided_gradcam``.
        path: Destination file for the figure.
    """
    explanation_panels = (
        ("gradcam", "GradCAM"),
        ("backpropagation", "Backpropagation"),
        ("guided_gradcam", "Guided GradCAM"),
    )

    fig, ax = plt.subplots(3, 4, figsize=(15, 15))
    try:
        for row, entry in enumerate(res.values()):
            ax[row, 0].imshow(entry["original_image"])
            # The closing parenthesis was missing from this title.
            ax[row, 0].set_title(
                f"Original Image (class: {entry['prediction']})"
            )
            ax[row, 0].axis("off")

            for col, (key, title) in enumerate(explanation_panels, start=1):
                ax[row, col].imshow(entry[key])
                ax[row, col].set_title(title)
                ax[row, col].axis("off")

        fig.tight_layout()
        fig.savefig(path, bbox_inches="tight")
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
# Build the classifier together with its preprocessing transform, then load
# the weights from model_path onto the CPU.
model, image_transform = get_model(model_name)
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
# NOTE(review): the model is put in training mode here, while
# make_prediction_and_explain() switches it to eval mode before predicting —
# confirm that training mode is really intended at this point.
model.train()
# Layer handed to GradCAM as its target layer.
target_layers = [model.conv_head]
# Guided-backpropagation wrapper around the same model, configured for CPU.
gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")
# Detector used by get_results() when object detection is requested.
yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
def _explain_to_tmp_file(image):
    """Explain ``image`` and save the figure to a uniquely named PNG in /tmp."""
    save_path = f"/tmp/with-bg-result-{uuid.uuid4().hex}.png"
    save_explanation_results(make_prediction_and_explain(image), save_path)
    return save_path


def get_results(img_path=None, img_for_testing=None, od=False):
    """Classify an image, save the explanation figure(s) and return their paths.

    Args:
        img_path: Path of an image file to load. Ignored when
            ``img_for_testing`` is also given.
        img_for_testing: Image as an array accepted by ``Image.fromarray``.
            Takes precedence over ``img_path`` when both are given.
        od: When true, run the YOLO detector first and explain the crop of
            the first box of each detection result instead of the whole
            image.

    Returns:
        List of paths of the PNG files written under ``/tmp``. With
        ``od=True`` the list is empty when nothing was detected.

    Raises:
        ValueError: If neither ``img_path`` nor ``img_for_testing`` is given.
    """
    if img_path is None and img_for_testing is None:
        raise ValueError("Either img_path or img_for_testing should be provided.")

    # Force 3-channel RGB: the CAM overlay blends a 3-channel heatmap with
    # the image, so RGBA or greyscale input would not line up.
    if img_for_testing is not None:
        image = Image.fromarray(img_for_testing).convert("RGB")
    else:
        # The context manager closes the file handle once pixels are copied.
        with Image.open(img_path) as opened:
            image = opened.convert("RGB")

    if not od:
        return [_explain_to_tmp_file(image)]

    result_paths = []
    # Detect on the very image that gets cropped, so box coordinates always
    # refer to it (previously detection could run on img_path while the crop
    # was taken from img_for_testing).
    for result in yolo_model(image):
        boxes = result.boxes.xyxy
        if len(boxes) == 0:
            # Nothing detected: skip instead of failing on boxes[0].
            continue
        left, top, right, bottom = (
            int(v) for v in boxes[0].cpu().numpy().astype(int)
        )
        cropped = image.crop((left, top, right, bottom))
        result_paths.append(_explain_to_tmp_file(cropped))
    return result_paths
if __name__ == "__main__":
    # Batch mode: explain the first detected object of every image found in
    # samples/ and write the figures to sample-results/.
    import os

    reproduce()
    model, image_transform = get_model(model_name)
    # Same map_location as the module-level load, so CUDA-saved weights also
    # load on a CPU-only machine (the explainers below run on the CPU).
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device("cpu"))
    )
    model.train()
    target_layers = [model.conv_head]
    gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")
    yolo_model = YOLO(YOLO_MODEL_WEIGHTS)

    # Saving the figure fails if the output directory does not exist yet.
    os.makedirs("sample-results", exist_ok=True)

    for IMAGE_PATH in glob("samples/*"):
        start = time.perf_counter()
        # Force 3-channel RGB so the CAM overlay matches the image channels.
        with Image.open(IMAGE_PATH) as opened:
            image = opened.convert("RGB")
        # Detect on the same image that is cropped below.
        results = yolo_model(image)
        for i, result in enumerate(results):
            boxes = result.boxes.xyxy
            if len(boxes) == 0:
                # Nothing detected in this image: skip it rather than crash.
                continue
            save_path = IMAGE_PATH.replace(
                "samples/", f"sample-results/with-white-bg-result-{i:02d}-"
            )
            bbox = boxes[0].cpu().numpy().astype(int)
            bbox_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
            res = make_prediction_and_explain(bbox_image)
            save_explanation_results(res, save_path)
        end = time.perf_counter() - start
        print(f"Completed in {end}s")