# Commit: b2ffc9b
import os
from typing import Tuple, List

import torch
import numpy as np
import torch.nn
import torch.nn.functional

from atoms_detection.detection import Detection
from atoms_detection.training_model import model_pipeline
from atoms_detection.image_preprocessing import dl_prepro_image
from utils.constants import ModelArgs
from utils.paths import PREDS_PATH


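class DLDetection(Detection):
    """Sliding-window detector that scores each pixel with a trained CNN.

    Every pixel of the (preprocessed) image is assigned the class-1 softmax
    probability that the model outputs for the window centered on it.
    """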
    def __init__(self,
                 model_name: ModelArgs,
                 ckpt_filename: str,
                 dataset_csv: str,
                 threshold: float,
                 detections_path: str,
                 inference_cache_path: str,
                 batch_size: int = 64,
                 ):
        self.model_name = model_name
        self.ckpt_filename = ckpt_filename
        self.device = self.get_torch_device()
        self.batch_size = batch_size

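        # Score every pixel: move the window one pixel at a time, and pad the
        # image by half the window size so border pixels also get a centered crop.
        self.stride = 1
        self.padding = 10
        self.window_size = (21, 21)  # (width, height) of each crop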

        super().__init__(dataset_csv, threshold, detections_path, inference_cache_path)

    @staticmethod
    def get_torch_device():
        """Return the best available device: Apple MPS, then CUDA, else CPU."""
        if torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        return device

    def sliding_window(self, image: np.ndarray, padding: int = 0):
        """Yield ``(center_x, center_y, crop)`` for every window position.

        The window moves over ``image`` with step ``self.stride``. ``padding``
        is subtracted from the center coordinates so that, when ``image`` is a
        padded copy, they index the original (unpadded) image.
        """
        # Offset from a window's top-left corner to its center pixel; for even
        # sizes this is the pixel just before the geometric center.
        x_to_center = (self.window_size[0] - 1) // 2
        y_to_center = (self.window_size[1] - 1) // 2

        for y in range(0, image.shape[0] - self.window_size[1] + 1, self.stride):
            for x in range(0, image.shape[1] - self.window_size[0] + 1, self.stride):
                center_x = x + x_to_center
                center_y = y + y_to_center
                yield center_x - padding, center_y - padding, image[y:y + self.window_size[1], x:x + self.window_size[0]]

    def batch_sliding_window(self, image: np.ndarray, padding: int = 0):
        """Group the windows produced by ``sliding_window`` into batches.

        Yields ``(x_idx_list, y_idx_list, images_list)`` with at most
        ``self.batch_size`` entries each; the last batch may be smaller.
        """
        x_idx_list = []
        y_idx_list = []
        images_list = []
        for _x, _y, _img in self.sliding_window(image, padding=padding):
            x_idx_list.append(_x)
            y_idx_list.append(_y)
            images_list.append(_img)
            if len(images_list) == self.batch_size:
                yield x_idx_list, y_idx_list, images_list
                x_idx_list = []
                y_idx_list = []
                images_list = []
        # Flush the final, partially filled batch.
        if images_list:
            yield x_idx_list, y_idx_list, images_list

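    def padding_image(self, img: np.ndarray) -> np.ndarray:
        """Zero-pad ``img`` by ``self.padding`` pixels on every side."""
        image_padded = np.zeros((img.shape[0] + self.padding * 2, img.shape[1] + self.padding * 2))
        # Explicit end indices (instead of -self.padding) also work when padding is 0.
        image_padded[self.padding:self.padding + img.shape[0], self.padding:self.padding + img.shape[1]] = img
        return image_padded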

    def load_model(self) -> torch.nn.Module:
        """Build the architecture from ``model_pipeline`` and load the checkpoint weights."""
        checkpoint = torch.load(self.ckpt_filename, map_location=self.device)

        model = model_pipeline[self.model_name](num_classes=2).to(self.device)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
        return model

    def images_to_torch_input(self, images_list: List[np.ndarray]) -> torch.Tensor:
        """Stack 2-D crops into a float32 tensor of shape (N, 1, H, W) on ``self.device``."""
        expanded_img = np.expand_dims(images_list, axis=1)
        input_tensor = torch.from_numpy(expanded_img).float()
        input_tensor = input_tensor.to(self.device)
        return input_tensor

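    def get_prediction_map(self, padded_image: np.ndarray) -> np.ndarray:
        """Score the unpadded image pixel by pixel.

        Returns an array the size of the original image holding, for each
        pixel, the class-1 softmax probability of the window centered on it.
        """
        _shape = padded_image.shape
        pred_map = np.zeros((_shape[0] - self.padding * 2, _shape[1] - self.padding * 2))
        model = self.load_model()
        # Inference only: disable autograd so no computation graph is kept per batch.
        with torch.no_grad():
            for x_idxs, y_idxs, image_crops in self.batch_sliding_window(padded_image, padding=self.padding):
                torch_input = self.images_to_torch_input(image_crops)
                output = model(torch_input)
                pred_prob = torch.nn.functional.softmax(output, 1)
                pred_prob = pred_prob.cpu().numpy()[:, 1]
                pred_map[np.array(y_idxs), np.array(x_idxs)] = pred_prob
        return pred_map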

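    def image_to_pred_map(self, img: np.ndarray, return_intermediate: bool = False):
        """Preprocess, pad, and score ``img``.

        Returns the prediction map, or the tuple
        ``(preprocessed_img, padded_image, pred_map)`` when
        ``return_intermediate`` is True.
        """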
        preprocessed_img = dl_prepro_image(img)
        print(f"preprocessed_img.shape: {preprocessed_img.shape}, μ: {np.mean(preprocessed_img)}, σ: {np.std(preprocessed_img)}")
        padded_image = self.padding_image(preprocessed_img)
        print(f"padded_image.shape: {padded_image.shape}, μ: {np.mean(padded_image)}, σ: {np.std(padded_image)}")
        pred_map = self.get_prediction_map(padded_image)
        print(f"pred_map.shape: {pred_map.shape}, μ: {np.mean(pred_map)}, σ: {np.std(pred_map)}")
        if return_intermediate:
            return preprocessed_img, padded_image, pred_map
        return pred_map
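

# Hypothetical helper, not part of the original module: a vectorized sketch of
# the crop extraction done by DLDetection.sliding_window, assuming stride 1 and
# NumPy >= 1.20 (which provides numpy.lib.stride_tricks.sliding_window_view).
# Instead of yielding windows one at a time, it returns all centers and crops
# at once, in the same order as the generator (row by row: y outer, x inner).
# Design note: the returned crops array holds one full window per pixel, so its
# memory grows with image size times window area; the batched generator above
# avoids that cost.
from typing import Tuple

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def _sliding_windows_vectorized(image: np.ndarray,
                                window_size: Tuple[int, int] = (21, 21),
                                padding: int = 0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return ``(center_x, center_y, crops)`` for every stride-1 window.

    ``window_size`` is (width, height), as in DLDetection.window_size.
    ``crops`` has shape (n_windows, height, width).
    """
    win_w, win_h = window_size
    # Same center convention as the original: for even sizes, the pixel just
    # before the geometric center.
    x_to_center = (win_w - 1) // 2
    y_to_center = (win_h - 1) // 2
    # Read-only view of shape (n_y, n_x, win_h, win_w); no data copied yet.
    views = sliding_window_view(image, (win_h, win_w))
    n_y, n_x = views.shape[:2]
    # Top-left corner of each window, laid out on the (n_y, n_x) grid.
    ys, xs = np.meshgrid(np.arange(n_y), np.arange(n_x), indexing="ij")
    center_x = (xs + x_to_center - padding).ravel()
    center_y = (ys + y_to_center - padding).ravel()
    # reshape copies the non-contiguous view into a dense stack of crops.
    crops = views.reshape(n_y * n_x, win_h, win_w)
    return center_x, center_y, crops


# Usage sketch: on an image padded by 10 pixels, the returned centers index
# the unpadded image, matching what sliding_window yields.
#   cx, cy, crops = _sliding_windows_vectorized(padded, window_size=(21, 21), padding=10)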