# Residue from the hosting site's file viewer, kept as comments so the module parses:
# nabeelraza's picture
# fix: gpu issue
# c9a01d9
# raw
# history blame
# 6.86 kB
import random
import time
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from model.load_model import get_model
from torchvision import transforms
from pytorch_grad_cam import GradCAM, GuidedBackpropReLUModel
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image
from ultralytics import YOLO
# from rembg import remove
import uuid
# Static variables
# Classifier architecture name (passed to get_model) and its checkpoint file.
model_name = "efficientnet_b0"
model_path = "efficientnet-b0-best.pth"
# Detector weights; its boxes decide which regions get classified and explained.
YOLO_MODEL_WEIGHTS = "yolo-v11-best.pt"
# Class labels; classes[i] names output index i of the classifier.
classes = ["Healthy", "Resistant", "Susceptible"]
# 224x224 centre crop (no resizing) applied to the image the overlays are drawn on.
resizing_transforms = transforms.Compose([transforms.CenterCrop(224)])
# Function definitions
def reproduce(seed=42):
    """Seed every random generator used here and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy, torch on the CPU and torch on all CUDA
    devices with the same value, then sets ``cudnn.deterministic`` and turns
    off ``cudnn.benchmark``.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
def get_grad_cam_results(image, transformed_image, class_index=0):
    """Compute a Grad-CAM map for one class and overlay it on ``image``.

    Uses the module-level ``model`` and ``target_layers``.

    Args:
        image: Background for the overlay; its pixel values are divided by
            255 before blending. Presumably the same height/width as the CAM
            (224x224 after ``resizing_transforms``) -- TODO confirm.
        transformed_image: A single, un-batched model input tensor; a batch
            dimension is added here.
        class_index: Index of the classifier output to explain.

    Returns:
        ``(visualization, grayscale_cam)``: the RGB overlay from
        ``show_cam_on_image`` and the CAM for the single input.
    """
    batch = transformed_image.unsqueeze(0)
    background = np.array(image) / 255.0
    with GradCAM(model=model, target_layers=target_layers) as cam:
        cams = cam(
            input_tensor=batch, targets=[ClassifierOutputTarget(class_index)]
        )
    grayscale_cam = cams[0]
    overlay = show_cam_on_image(background, grayscale_cam, use_rgb=True)
    return overlay, grayscale_cam
def get_backpropagation_results(transformed_image, class_index=0):
    """Run the module-level guided-backprop model for one class.

    Args:
        transformed_image: A single, un-batched model input tensor; a batch
            dimension is added before calling ``gbp_model``.
        class_index: Target category passed to ``gbp_model``.

    Returns:
        ``(raw, display)``: the unmodified ``gbp_model`` output and the same
        map passed through ``deprocess_image`` for display.
    """
    raw = gbp_model(transformed_image.unsqueeze(0), target_category=class_index)
    return raw, deprocess_image(raw)
def get_guided_gradcam(image, cam_grayscale, bp):
    """Overlay a Guided Grad-CAM map on ``image``.

    The Guided Grad-CAM map is the element-wise product of the Grad-CAM map
    (tiled over three channels) and the guided-backpropagation output.

    Args:
        image: Background image with pixel values in [0, 255].
        cam_grayscale: 2-D Grad-CAM map.
        bp: Raw guided-backpropagation output. Assumed to be (H, W, 3) so it
            matches the tiled CAM -- TODO confirm.

    Returns:
        The overlay produced by ``show_cam_on_image`` with ``use_rgb=False``.
    """
    # Tile the single-channel CAM across the three channels of ``bp``.
    cam_mask = np.repeat(cam_grayscale[..., np.newaxis], 3, axis=-1)
    guided = deprocess_image(cam_mask * bp)
    # deprocess_image yields uint8 values in [0, 255], while show_cam_on_image
    # expects a float mask in [0, 1] and scales it by 255 itself. Passing the
    # uint8 array directly made `255 * mask` wrap modulo 256 (inverting the map),
    # so normalise it first.
    # NOTE(review): get_grad_cam_results uses use_rgb=True; confirm the
    # different colour order here is intended.
    return show_cam_on_image(
        np.array(image) / 255.0, guided / 255.0, use_rgb=False
    )
def explain_results(image, class_index=0):
    """Build the Grad-CAM, guided-backprop and Guided Grad-CAM views for one class.

    Args:
        image: PIL image to explain (a detector crop when called from this module).
        class_index: Index of the classifier output to explain.

    Returns:
        ``(gradcam_overlay, backprop_display, guided_gradcam_overlay)``.
    """
    model_input = image_transform(image)
    # Overlays are drawn on a 224x224 centre crop of the untransformed image.
    # NOTE(review): the crop is taken without resizing; if image_transform
    # resizes before cropping, the CAM and this background will not line up.
    display_image = resizing_transforms(image)

    gradcam_overlay, cam_mask = get_grad_cam_results(
        display_image, model_input, class_index
    )
    raw_bp, bp_display = get_backpropagation_results(model_input, class_index)
    guided_overlay = get_guided_gradcam(display_image, cam_mask, raw_bp)
    return gradcam_overlay, bp_display, guided_overlay
def make_prediction_and_explain(image):
transformed_image = image_transform(image)
transformed_image = transformed_image.unsqueeze(0)
model.eval()
with torch.no_grad():
output = model(transformed_image)
output = torch.nn.functional.softmax(output, dim=1)
predictions = [round(x, 4) * 100 for x in output[0].tolist()]
results = {}
for i, k in enumerate(classes):
gradcam, bp_deprocessed, guided_gradcam = explain_results(image, class_index=i)
results[k] = {
"original_image": image,
"prediction": f"{k} ({predictions[i]}%)",
"gradcam": gradcam,
"backpropagation": bp_deprocessed,
"guided_gradcam": guided_gradcam,
}
return results
def save_explanation_results(res, path):
    """Render per-class explanations as an image grid and save it to ``path``.

    The grid has one row per entry of ``res`` (as returned by
    ``make_prediction_and_explain``) and four columns: the original image,
    Grad-CAM, guided backpropagation and Guided Grad-CAM.

    Args:
        res: Mapping of class name -> dict with keys ``original_image``,
            ``prediction``, ``gradcam``, ``backpropagation`` and ``guided_gradcam``.
        path: Output file path passed to ``savefig``.
    """
    keys = ("original_image", "gradcam", "backpropagation", "guided_gradcam")
    # plt.subplots rejects 0 rows; squeeze=False keeps `axes` 2-D for 1 row.
    n_rows = max(len(res), 1)
    fig, axes = plt.subplots(n_rows, 4, figsize=(15, 5 * n_rows), squeeze=False)
    try:
        for row, v in zip(axes, res.values()):
            titles = (
                f"Original Image (class: {v['prediction']})",
                "GradCAM",
                "Backpropagation",
                "Guided GradCAM",
            )
            for ax, key, title in zip(row, keys, titles):
                ax.imshow(v[key])
                ax.set_title(title)
                ax.axis("off")
        fig.tight_layout()
        fig.savefig(path, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
# Module-level model setup, run on import so get_results() can be called directly.
# Build the classifier and the preprocessing pipeline used for its input.
model, image_transform = get_model(model_name)
# map_location forces every checkpoint tensor onto the CPU, so GPU-saved weights
# also load on CPU-only hosts. torch.load unpickles the file: only load trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# NOTE(review): train mode is set here, yet make_prediction_and_explain calls
# model.eval() before predicting; confirm whether this is still needed.
model.train()
# Grad-CAM target: the classifier's `conv_head` layer.
target_layers = [model.conv_head]
# Guided-backpropagation wrapper around the same classifier, on the CPU.
gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")
# Detector whose boxes get_results() crops before classification.
yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
def get_results(img_path=None, img_for_testing=None):
if img_path is None and img_for_testing is None:
raise ValueError("Either img_path or img_for_testing should be provided.")
if img_path is not None:
results = yolo_model(img_path)
image = Image.open(img_path)
if img_for_testing is not None:
results = yolo_model(img_for_testing)
image = Image.fromarray(img_for_testing)
result_paths = []
for i, result in enumerate(results):
unique_id = uuid.uuid4().hex
save_path = f"/tmp/with-white-bg-result-{unique_id}.png"
bbox = result.boxes.xyxy[0].cpu().numpy().astype(int)
bbox_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
# bbox_image = remove(bbox_image).convert("RGB")
# bbox_image = Image.fromarray(
# np.where(
# np.array(bbox_image) == [0, 0, 0],
# [255, 255, 255],
# np.array(bbox_image),
# ).astype(np.uint8)
# )
res = make_prediction_and_explain(bbox_image)
save_explanation_results(res, save_path)
result_paths.append(save_path)
return result_paths
if __name__ == "__main__":
    # Batch mode: explain every image in samples/ and write the figures to sample-results/.
    import os

    reproduce()
    # The classifier, Grad-CAM target layers, guided-backprop model and YOLO
    # detector are already built at module level (weights mapped to the CPU).
    # The former reload here called torch.load without map_location and failed
    # on CPU-only hosts for GPU-saved checkpoints.
    output_dir = "sample-results"
    os.makedirs(output_dir, exist_ok=True)
    for image_path in sorted(glob(os.path.join("samples", "*"))):
        if not os.path.isfile(image_path):
            continue
        start = time.perf_counter()
        results = yolo_model(image_path)
        # RGB guarantees 3 channels; convert() loads the pixels so the file closes.
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")
        base_name = os.path.basename(image_path)
        for i, result in enumerate(results):
            if result.boxes is None or len(result.boxes.xyxy) == 0:
                # Nothing detected: skip instead of raising IndexError.
                continue
            # Built from the basename so the input file is never overwritten,
            # whatever the platform's path separator.
            save_path = os.path.join(
                output_dir, f"with-white-bg-result-{i:02d}-{base_name}"
            )
            bbox = result.boxes.xyxy[0].cpu().numpy().astype(int)
            bbox_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
            res = make_prediction_and_explain(bbox_image)
            save_explanation_results(res, save_path)
        end = time.perf_counter() - start
        print(f"Completed in {end}s")