from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModel, Pipeline


class InkDetectionPipeline(Pipeline):
    """Sliding-window inference over a 3D volume, producing a 2D mask.

    The pipeline:

    1. Takes a 3D image of shape (m, n, d).
    2. Cuts it into (tile_size, tile_size, d) tiles with a given stride,
       adding edge-aligned tiles so the right and bottom borders are covered.
    3. Runs the model on batches of tiles (the model is 3D-to-2D).
    4. Averages the overlapping tile predictions into a full-size (m, n)
       output.
    """

    def __init__(
        self,
        model,
        device='cuda',
        tile_size=64,
        stride=32,
        scale_factor=16,
        batch_size=32,
        **kwargs,
    ):
        """Store the tiling configuration and move the model to ``device``.

        Args:
            model: Module called on tensors of shape (B, 1, d, tile, tile).
            device: Anything ``torch.device`` accepts ('cuda', 'cuda:1',
                'cpu', ...).
            tile_size: Side length, in pixels, of the square tiles.
            stride: Step, in pixels, between consecutive tile origins.
            scale_factor: Factor by which the model output is upsampled
                before it is pasted back; 1 disables upsampling.
            batch_size: Number of tiles per forward pass.
            **kwargs: Accepted for call compatibility; not used.
        """
        torch_device = torch.device(device)
        # The base class is given a device ordinal (-1 for CPU, otherwise a
        # CUDA index). Honour explicit indices such as 'cuda:1' instead of
        # recognising only the bare string 'cuda'.
        if torch_device.type == 'cuda':
            ordinal = torch_device.index if torch_device.index is not None else 0
        else:
            ordinal = -1
        super().__init__(model=model, tokenizer=None, device=ordinal)

        self.model = model.to(device)
        # Keep the caller's original device spec so existing readers of
        # ``pipeline.device`` see the value they passed in.
        self.device = device
        self.tile_size = tile_size
        self.stride = stride
        self.scale_factor = scale_factor
        self.batch_size = batch_size

    @staticmethod
    def _tile_starts(length: int, tile: int, stride: int) -> List[int]:
        """Return tile origins along one axis, ending flush with the border."""
        starts = list(range(0, length - tile + 1, stride))
        last = length - tile
        # Without this extra origin the trailing ``(length - tile) % stride``
        # pixels would never be covered by any tile.
        if starts[-1] != last:
            starts.append(last)
        return starts

    def preprocess(
        self, inputs
    ) -> Tuple[np.ndarray, List[Tuple[int, int, int, int]], Tuple[int, int]]:
        """Cut the input volume into overlapping, channel-first tiles.

        Args:
            inputs: Array-like of shape (m, n, d).

        Returns:
            tiles: float16 array of shape (num_tiles, d, tile_size, tile_size).
            coords: List of (x1, y1, x2, y2) pixel boxes, one per tile, in
                row-major order (y outer, x inner).
            full_shape: The (m, n) size of the input plane.

        Raises:
            ValueError: If the input is not 3D, is smaller than one tile in
                either spatial dimension, or the stride is not positive.
        """
        volume = np.asarray(inputs)
        if volume.ndim != 3:
            raise ValueError(
                f"expected a 3D volume of shape (m, n, d), got shape {volume.shape}"
            )
        m, n, d = volume.shape
        tile = self.tile_size
        if self.stride <= 0:
            raise ValueError(f"stride must be positive, got {self.stride}")
        if m < tile or n < tile:
            raise ValueError(
                f"volume plane {(m, n)} is smaller than tile_size={tile}; "
                "no tile can be extracted"
            )

        y_starts = self._tile_starts(m, tile, self.stride)
        x_starts = self._tile_starts(n, tile, self.stride)

        # float16 storage halves the memory of the tile stack; _forward casts
        # back to float32 before the model sees the data.
        tiles = np.empty((len(y_starts) * len(x_starts), d, tile, tile), dtype=np.float16)
        coords = []
        for y1 in y_starts:
            for x1 in x_starts:
                patch = volume[y1:y1 + tile, x1:x1 + tile]
                # (tile, tile, d) -> (d, tile, tile)
                tiles[len(coords)] = patch.transpose(2, 0, 1)
                coords.append((x1, y1, x1 + tile, y1 + tile))
        return tiles, coords, (m, n)

    def _forward(self, model_inputs):
        """Run the model over the tiles in batches of ``self.batch_size``.

        Args:
            model_inputs: np.ndarray of shape (B, d, tile_size, tile_size).

        Each batch is cast to float32 and given a singleton channel axis, so
        the model receives (b, 1, d, tile_size, tile_size). A sigmoid is
        applied to whatever the model returns.

        Returns:
            np.ndarray (float32) of the per-tile probabilities concatenated
            along axis 0.
        """
        patches = model_inputs
        # Autocast needs a device *type*; comparing the raw spec against
        # 'cuda' would send 'cuda:1' down the CPU autocast path.
        on_cuda = torch.device(self.device).type == 'cuda'
        autocast_type = 'cuda' if on_cuda else 'cpu'

        all_preds = []
        with torch.no_grad():
            for start_idx in tqdm(range(0, len(patches), self.batch_size)):
                chunk = patches[start_idx:start_idx + self.batch_size]
                sub_batch = torch.from_numpy(chunk.astype(np.float32)).unsqueeze(1)

                with torch.autocast(autocast_type):
                    sub_y_preds = self.model(sub_batch.to(self.device))

                sub_y_preds = torch.sigmoid(sub_y_preds)
                # Cast to float32 before leaving torch: autocast can produce
                # reduced-precision tensors.
                all_preds.append(sub_y_preds.detach().cpu().float().numpy())

        return np.concatenate(all_preds, axis=0)

    def postprocess(self, model_outputs, coords, full_shape):
        """Stitch per-tile predictions into one averaged full-size mask.

        Args:
            model_outputs: np.ndarray of per-tile predictions. Expected to be
                (B, 1, h, w) with ``h * scale_factor == tile_size`` (same for
                w) -- this depends on the model and is checked below.
            coords: List of (x1, y1, x2, y2) for each tile.
            full_shape: (m, n) size of the output mask.

        Steps:
            - Upsample each prediction by ``scale_factor`` (bilinear) unless
              the factor is 1.
            - Add each tile into an (m, n) accumulator and count how many
              tiles touched every pixel.
            - Divide the sum by the count; untouched pixels stay 0.

        Returns:
            np.ndarray (float32) of shape (m, n).

        Raises:
            ValueError: If the upsampled tiles are not
                (tile_size, tile_size).
        """
        m, n = full_shape

        preds = torch.from_numpy(model_outputs.astype(np.float32))
        if self.scale_factor != 1:
            preds = F.interpolate(
                preds, scale_factor=self.scale_factor, mode='bilinear', align_corners=False
            )
        preds = preds.squeeze(1).numpy()

        tile = self.tile_size
        if preds.ndim != 3 or preds.shape[1:] != (tile, tile):
            raise ValueError(
                f"tile predictions have shape {preds.shape} after upsampling by "
                f"{self.scale_factor}; expected (B, {tile}, {tile})"
            )

        mask_pred = np.zeros((m, n), dtype=np.float32)
        mask_count = np.zeros((m, n), dtype=np.float32)
        for i, (x1, y1, _x2, _y2) in enumerate(coords):
            window = (slice(y1, y1 + tile), slice(x1, x1 + tile))
            mask_pred[window] += preds[i]
            mask_count[window] += 1.0

        # Mean over overlapping tiles; `where` avoids dividing by zero on
        # pixels that no tile covered.
        return np.divide(
            mask_pred, mask_count, out=np.zeros_like(mask_pred), where=mask_count != 0
        )

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess/forward/postprocess parameter dicts."""
        return {}, {}, {}

    def __call__(self, image: np.ndarray):
        """Predict a mask for a whole volume.

        Args:
            image: np.ndarray of shape (m, n, d) input volume.

        Returns:
            mask_pred: np.ndarray of shape (m, n) predicted mask.
        """
        tiles, coords, full_shape = self.preprocess(image)
        outputs = self._forward(tiles)
        return self.postprocess(outputs, coords, full_shape)