import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import Pipeline


class InkDetectionPipeline(Pipeline):
    """
    A custom pipeline that:
    1. Takes in a 3D image of shape (m, n, d).
    2. Cuts it into (tile_size, tile_size, d) tiles with a given stride.
    3. Runs model inference on each tile (the model maps 3D tiles to 2D predictions).
    4. Reconstructs the tile predictions into a full-size output mask.
    """

    def __init__(self, model, device='cuda', tile_size=64, stride=32,
                 scale_factor=16, batch_size=32, **kwargs):
        super().__init__(model=model, tokenizer=None,
                         device=0 if device == 'cuda' else -1)
        self.model = model.to(device)
        self.device = device
        self.tile_size = tile_size
        self.stride = stride
        self.scale_factor = scale_factor
        self.batch_size = batch_size

    def preprocess(self, inputs):
        """
        inputs: np.ndarray of shape (m, n, d)

        Cuts the input volume into overlapping tiles of shape
        (tile_size, tile_size, d) with the configured stride.

        Returns:
            tiles: np.ndarray of shape (B, d, tile_size, tile_size)
            coords: list of (x1, y1, x2, y2) tile coordinates
            full_shape: (m, n)
        """
        volume = inputs
        m, n, d = volume.shape
        tiles = []
        coords = []
        # Extract overlapping patches. Note: rows/columns past the last full
        # stride step are not covered; pad the volume beforehand if full
        # coverage is required.
        for y in range(0, m - self.tile_size + 1, self.stride):
            for x in range(0, n - self.tile_size + 1, self.stride):
                y1, y2 = y, y + self.tile_size
                x1, x2 = x, x + self.tile_size
                patch = volume[y1:y2, x1:x2]  # (tile_size, tile_size, d)
                # Move depth first so it becomes the model's slice axis.
                tiles.append(patch.transpose(2, 0, 1))
                coords.append((x1, y1, x2, y2))
        # float16 halves the memory footprint; tiles are upcast per batch
        # during inference.
        return np.array(tiles, dtype=np.float16), coords, (m, n)

    def _forward(self, model_inputs):
        """
        model_inputs: np.ndarray of shape (B, d, tile_size, tile_size)

        The model expects input of shape (B, 1, d, tile_size, tile_size) and
        returns logits of shape (B, 1, tile_size / scale_factor,
        tile_size / scale_factor); postprocess upsamples them back to tile
        size. Tiles are processed in chunks of `self.batch_size` to bound
        memory usage.
        """
        patches = model_inputs
        B = len(patches)
        all_preds = []
        # Process in batches to save memory
        for start_idx in tqdm(range(0, B, self.batch_size)):
            end_idx = start_idx + self.batch_size
            # (subB, d, tile_size, tile_size), upcast from float16
            sub_batch = torch.from_numpy(
                patches[start_idx:end_idx].astype(np.float32))
            # Add channel dimension: (subB, 1, d, tile_size, tile_size)
            sub_batch = sub_batch.unsqueeze(1)
            with torch.no_grad(), torch.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'):
                sub_y_preds = self.model(sub_batch.to(self.device))
            # Map logits to probabilities
            sub_y_preds = torch.sigmoid(sub_y_preds)
            # Move to CPU and numpy: (subB, 1, h, w)
            sub_y_preds = sub_y_preds.detach().cpu().float().numpy()
            all_preds.append(sub_y_preds)
        # Concatenate along the batch dimension: (B, 1, h, w)
        y_preds = np.concatenate(all_preds, axis=0)
        return y_preds

    def postprocess(self, model_outputs, coords, full_shape):
        """
        model_outputs: np.ndarray of shape (B, 1, h, w), where
                       h = w = tile_size / scale_factor
        coords: list of (x1, y1, x2, y2) for each tile
        full_shape: (m, n)

        Steps:
        - Upsample each tile prediction by scale_factor back to
          (tile_size, tile_size) if needed.
        - Accumulate each tile prediction into a full (m, n) array.
        - Divide by the per-pixel tile count so overlapping predictions
          are averaged (uniform weighting).
        """
        m, n = full_shape
        # Accumulators for summed predictions and per-pixel tile counts
        mask_pred = np.zeros((m, n), dtype=np.float32)
        mask_count = np.zeros((m, n), dtype=np.float32)

        # Upsample each prediction back to tile size if needed.
        preds_tensor = torch.from_numpy(model_outputs.astype(np.float32))  # (B, 1, h, w)
        if self.scale_factor != 1:
            preds_tensor = F.interpolate(
                preds_tensor,
                scale_factor=self.scale_factor,
                mode='bilinear',
                align_corners=False,
            )
        preds_tensor = preds_tensor.squeeze(1).numpy()  # (B, tile_size, tile_size)
        out_tile_size = self.tile_size
        assert preds_tensor.shape[-1] == out_tile_size, \
            'Upsampled tiles must match tile_size; check scale_factor.'

        for i, (x1, y1, x2, y2) in enumerate(coords):
            # After upsampling, each tile covers exactly its original
            # (tile_size, tile_size) region.
            y2_up = y1 + out_tile_size
            x2_up = x1 + out_tile_size
            mask_pred[y1:y2_up, x1:x2_up] += preds_tensor[i]
            mask_count[y1:y2_up, x1:x2_up] += 1.0
        # Average overlapping predictions; leave uncovered pixels at zero.
        mask_pred = np.divide(mask_pred, mask_count,
                              out=np.zeros_like(mask_pred),
                              where=mask_count != 0)
        return mask_pred

    def _sanitize_parameters(self, **kwargs):
        # No extra parameters are routed to preprocess/_forward/postprocess.
        return {}, {}, {}

    def __call__(self, image: np.ndarray):
        """
        Args:
            image: np.ndarray of shape (m, n, d) input volume.

        Returns:
            mask_pred: np.ndarray of shape (m, n), the predicted mask.
        """
        tiles, coords, full_shape = self.preprocess(image)
        # _forward already chunks the tiles into batches of `self.batch_size`,
        # so large images are handled without exhausting memory.
        outputs = self._forward(tiles)
        mask_pred = self.postprocess(outputs, coords, full_shape)
        return mask_pred
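

# --- Example usage (illustrative sketch) ---
# `ToyInkModel` below is a hypothetical stand-in for a real 3D-to-2D ink
# detection checkpoint. The transformers Pipeline base class infers the
# framework from `model.config`, so the toy conv net is wrapped in a
# PreTrainedModel. It outputs full-resolution tiles, hence scale_factor=1;
# a real model emitting tile_size/16 maps would keep the default of 16.
if __name__ == '__main__':
    import torch.nn as nn
    from transformers import PretrainedConfig, PreTrainedModel

    class ToyInkModel(PreTrainedModel):
        """Toy model: averages the depth axis, then applies a 2D conv."""
        config_class = PretrainedConfig

        def __init__(self, config):
            super().__init__(config)
            self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

        def forward(self, x):        # x: (B, 1, d, tile_size, tile_size)
            x = x.mean(dim=2)        # collapse depth -> (B, 1, tile_size, tile_size)
            return self.conv(x)      # full-resolution logits

    pipe = InkDetectionPipeline(
        model=ToyInkModel(PretrainedConfig()),
        device='cuda' if torch.cuda.is_available() else 'cpu',
        scale_factor=1,              # toy model already outputs tile-size maps
        batch_size=8,
    )
    volume = np.random.rand(256, 256, 16).astype(np.float32)  # fake (m, n, d) volume
    mask = pipe(volume)
    print(mask.shape)                # -> (256, 256)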