import torch
import torch.nn.functional as F
import numpy as np
from transformers import Pipeline
from tqdm import tqdm


class InkDetectionPipeline(Pipeline):
    """
    A custom pipeline that:
    1. Takes in a 3D image: shape (m, n, d).
    2. Cuts it into (64, 64, d) tiles with a given stride.
    3. Runs the model inference on each tile (model is 3D-to-2D).
    4. Reconstructs the predictions into a full-size output.
    """

    def __init__(self, model, device='cuda', tile_size=64, stride=32, scale_factor=16, batch_size=32, **kwargs):
        # The transformers Pipeline base class expects a device index:
        # 0 for the first GPU, -1 for CPU.
        super().__init__(model=model, tokenizer=None, device=0 if device == 'cuda' else -1)
        self.model = model.to(device)
        self.device = device
        self.tile_size = tile_size
        self.stride = stride
        self.scale_factor = scale_factor
        self.batch_size = batch_size

    def preprocess(self, inputs):
        """
        inputs: np.ndarray of shape (m, n, d)
        This function cuts the input volume into tiles of shape (tile_size, tile_size, d)
        with a given stride.
        Returns:
          tiles: list of np arrays each (tile_size, tile_size, d)
          coords: list of (x1, y1, x2, y2) coords
        """
        volume = inputs
        m, n, d = volume.shape
        tiles = []
        coords = []
        # Extract overlapping patches
        for y in range(0, m - self.tile_size + 1, self.stride):
            for x in range(0, n - self.tile_size + 1, self.stride):
                y1, y2 = y, y + self.tile_size
                x1, x2 = x, x + self.tile_size
                patch = volume[y1:y2, x1:x2]  # (tile_size, tile_size, d)
                tiles.append(patch.transpose(2, 0, 1))  # -> (d, tile_size, tile_size)
                coords.append((x1, y1, x2, y2))
        # float16 keeps the tile stack compact; _forward casts each batch to float32
        return np.array(tiles, dtype=np.float16), coords, (m, n)

    def _forward(self, model_inputs):
        """
        model_inputs: a list of patches (B, tile_size, tile_size, d)
        The model expects input: (B, C=1, H=tile_size, W=tile_size)
        and returns (B, 1, H=tile_size, W=tile_size).

        We'll add batching using a for loop. We assume `self.batch_size` is defined.
        """

        patches = model_inputs
        B = len(patches)
        all_preds = []
        # Process in batches to save memory
        for start_idx in tqdm(range(0, B, self.batch_size)):
            end_idx = start_idx + self.batch_size
            sub_batch = torch.from_numpy(patches[start_idx:end_idx].astype(np.float32))  # shape: (subB, d, tile_size, tile_size)

            # Add channel dimension: (subB, 1, d, tile_size, tile_size)
            sub_batch = sub_batch.unsqueeze(1)

            # Mixed-precision inference
            with torch.no_grad(), torch.autocast(self.device if self.device == 'cuda' else 'cpu'):
                sub_y_preds = self.model(sub_batch.to(self.device))  # (subB, 1, h, w)

            # Apply sigmoid
            sub_y_preds = torch.sigmoid(sub_y_preds)

            # Move to CPU as float32 numpy; shape (subB, 1, h, w)
            sub_y_preds = sub_y_preds.detach().cpu().float().numpy()

            all_preds.append(sub_y_preds)

        # Concatenate along the batch dimension
        y_preds = np.concatenate(all_preds, axis=0)  # (B, 1, h, w)

        return y_preds

    def postprocess(self, model_outputs, coords, full_shape):
        """
        model_outputs: np.ndarray of shape (B, 1, tile_size, tile_size)
        coords: list of (x1, y1, x2, y2) for each tile
        full_shape: (m,n)

        We need to:
        - Place each tile prediction into a full (m,n) array
        - Use the kernel to weight and sum predictions
        - Divide by count
        - Optionally upsample by scale_factor if required
        """
        m, n = full_shape
        # We will create mask_pred and mask_count to accumulate predictions
        mask_pred = np.zeros((m, n), dtype=np.float32)
        mask_count = np.zeros((m, n), dtype=np.float32)
        # Upsample each prediction back to tile_size using bilinear interpolation
        preds_tensor = torch.from_numpy(model_outputs.astype(np.float32))  # (B, 1, h, w)
        if self.scale_factor != 1:
            preds_tensor = F.interpolate(
                preds_tensor, scale_factor=self.scale_factor, mode='bilinear', align_corners=False
            )
        preds_tensor = preds_tensor.squeeze(1).numpy()  # (B, tile_size, tile_size)

        out_tile_size = self.tile_size

        for i, (x1, y1, x2, y2) in enumerate(coords):
            # After upsampling, each tile again spans tile_size pixels, so the
            # recomputed corners coincide with the original (x2, y2)
            y2_up = y1 + out_tile_size
            x2_up = x1 + out_tile_size

            mask_pred[y1:y2_up, x1:x2_up] += preds_tensor[i]
            mask_count[y1:y2_up, x1:x2_up] += 1.0
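        # With the defaults (tile_size=64, stride=32), interior pixels are covered
        # by 2x2 = 4 overlapping tiles, so mask_count is 4 there and smaller near
        # the borders of the volume.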

        # Average overlapping predictions; pixels never covered by a tile stay 0
        mask_pred = np.divide(mask_pred, mask_count, out=np.zeros_like(mask_pred), where=mask_count != 0)
        
        return mask_pred

    def _sanitize_parameters(self, **kwargs):
        # Required by the Pipeline API; no call-time parameters are routed here
        return {}, {}, {}

    def __call__(self, image: np.ndarray):
        """
        Args:
            image: np.ndarray of shape (m, n, d) input volume.
        Returns:
            mask_pred: np.ndarray of shape (m, n) predicted probability mask.
        """
        tiles, coords, full_shape = self.preprocess(image)
        # _forward already chunks the tiles into batches of self.batch_size
        outputs = self._forward(tiles)
        mask_pred = self.postprocess(outputs, coords, full_shape)
        return mask_pred
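

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original pipeline).
    # `DummyInkModel` is a hypothetical stand-in for a real ink-detection network;
    # it implements the contract assumed above: (B, 1, d, 64, 64) volumes in,
    # (B, 1, 4, 4) logits out (tile_size=64, scale_factor=16). It subclasses
    # PreTrainedModel because the transformers Pipeline base class expects one.
    import torch.nn as nn
    from transformers import PretrainedConfig, PreTrainedModel

    class DummyInkModel(PreTrainedModel):
        config_class = PretrainedConfig

        def __init__(self):
            super().__init__(PretrainedConfig())
            # A dummy parameter so the model has a well-defined device
            self.scale = nn.Parameter(torch.ones(1))

        def forward(self, x):  # x: (B, 1, d, 64, 64)
            x = x.mean(dim=2)  # collapse the depth axis -> (B, 1, 64, 64)
            return F.avg_pool2d(x, kernel_size=16) * self.scale  # -> (B, 1, 4, 4)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pipe = InkDetectionPipeline(DummyInkModel(), device=device, batch_size=8)
    volume = np.random.rand(256, 256, 16).astype(np.float16)  # (m, n, d)
    mask = pipe(volume)
    print(mask.shape, float(mask.min()), float(mask.max()))  # (256, 256), values in (0, 1)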