import torch
import torch.nn.functional as F
import numpy as np
from transformers import Pipeline, AutoModel
from tqdm import tqdm

class InkDetectionPipeline(Pipeline):
    """
    A custom Hugging Face pipeline that:
      1. Takes a 3D volume of shape (m, n, d).
      2. Cuts it into overlapping (tile_size, tile_size, d) tiles with a given stride.
      3. Runs batched inference on the tiles (the model maps 3D inputs to 2D logits).
      4. Upsamples and stitches the tile predictions back into a full-size (m, n)
         mask, averaging where tiles overlap.

    See the usage sketch at the bottom of this file.
    """
    def __init__(self, model, device='cuda', tile_size=64, stride=32, scale_factor=16, batch_size=32, **kwargs):
        super().__init__(model=model, tokenizer=None, device=0 if device == 'cuda' else -1)
        self.model = model.to(device)
        self.device = device
        self.tile_size = tile_size        # spatial size of each square tile
        self.stride = stride              # step between tile origins; stride < tile_size gives overlap
        self.scale_factor = scale_factor  # upsampling factor applied to the model's per-tile output
        self.batch_size = batch_size      # number of tiles per forward pass

    def preprocess(self, inputs):
        """
        inputs: np.ndarray of shape (m, n, d)
        Cuts the input volume into overlapping tiles of shape (tile_size, tile_size, d)
        with the configured stride. Regions near the right/bottom border that do not
        fit a full tile are skipped, so m and n should be compatible with tile_size
        and stride (e.g. multiples of the stride).
        Returns:
            tiles: np.ndarray of shape (num_tiles, d, tile_size, tile_size)
            coords: list of (x1, y1, x2, y2) tile coordinates
            full_shape: (m, n) of the input volume
        """
        volume = inputs
        m, n, d = volume.shape
        tiles = []
        coords = []
        # Extract overlapping patches in row-major order
        for y in range(0, m - self.tile_size + 1, self.stride):
            for x in range(0, n - self.tile_size + 1, self.stride):
                y1, y2 = y, y + self.tile_size
                x1, x2 = x, x + self.tile_size
                patch = volume[y1:y2, x1:x2]  # shape (tile_size, tile_size, d)
                # Move depth first so each tile is (d, tile_size, tile_size)
                tiles.append(patch.transpose(2, 0, 1))
                coords.append((x1, y1, x2, y2))
        # float16 keeps the tile stack small in memory; it is cast back to
        # float32 just before inference in _forward.
        return np.array(tiles, dtype=np.float16), coords, (m, n)
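
    # Tiling illustration (comments only, not executed): with m = n = 128,
    # tile_size = 64, and stride = 32, the loop above yields tile origins
    # (y, x) in {0, 32, 64} x {0, 32, 64}, i.e. nine overlapping tiles.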

    def _forward(self, model_inputs):
        """
        model_inputs: np.ndarray of tiles, shape (B, d, tile_size, tile_size)
        A channel dimension is added so the model sees (subB, 1, d, tile_size, tile_size),
        and the model returns per-tile 2D logits of shape (subB, 1, h, w), where
        h * scale_factor == tile_size (see postprocess). Tiles are processed in
        chunks of self.batch_size to bound memory use.
        """
        patches = model_inputs
        B = len(patches)
        all_preds = []
        # Process in batches to save memory
        for start_idx in tqdm(range(0, B, self.batch_size)):
            end_idx = start_idx + self.batch_size
            # (subB, d, tile_size, tile_size); cast back up from float16
            sub_batch = torch.from_numpy(patches[start_idx:end_idx].astype(np.float32))
            # Add channel dimension: (subB, 1, d, tile_size, tile_size)
            sub_batch = sub_batch.unsqueeze(1)
            with torch.no_grad(), torch.autocast(self.device if self.device == 'cuda' else 'cpu'):
                sub_y_preds = self.model(sub_batch.to(self.device))
            # Map logits to [0, 1] probabilities
            sub_y_preds = torch.sigmoid(sub_y_preds)
            # Move to CPU / numpy at full precision
            sub_y_preds = sub_y_preds.detach().cpu().float().numpy()
            all_preds.append(sub_y_preds)
        # Concatenate along the batch dimension: (B, 1, h, w)
        y_preds = np.concatenate(all_preds, axis=0)
        return y_preds
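
    # Shape walk-through (illustrative; the depth d and the model's output
    # size are assumptions, not fixed by this file): with d = 26,
    # tile_size = 64, scale_factor = 16, and batch_size = 32, each chunk
    # enters the model as (32, 1, 26, 64, 64) and comes back as
    # (32, 1, 4, 4), since 64 // 16 = 4 per side.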

    def postprocess(self, model_outputs, coords, full_shape):
        """
        model_outputs: np.ndarray of shape (B, 1, h, w), the per-tile predictions
        coords: list of (x1, y1, x2, y2) for each tile
        full_shape: (m, n) of the input volume
        Steps:
          - Upsample each tile prediction by scale_factor so it spans tile_size x tile_size
          - Accumulate each upsampled tile into a full (m, n) array
          - Track the per-pixel tile count and divide to average overlapping tiles
        """
        m, n = full_shape
        # mask_pred accumulates predictions; mask_count tracks tiles per pixel
        mask_pred = np.zeros((m, n), dtype=np.float32)
        mask_count = np.zeros((m, n), dtype=np.float32)
        # Upsample each tile prediction back to tile_size x tile_size
        preds_tensor = torch.from_numpy(model_outputs.astype(np.float32))  # (B, 1, h, w)
        if self.scale_factor != 1:
            preds_tensor = F.interpolate(
                preds_tensor, scale_factor=self.scale_factor, mode='bilinear', align_corners=False
            )
        preds_tensor = preds_tensor.squeeze(1).numpy()  # (B, tile_size, tile_size)
        out_tile_size = self.tile_size
        for i, (x1, y1, x2, y2) in enumerate(coords):
            # After upsampling, each prediction spans exactly one tile footprint
            y2_up = y1 + out_tile_size
            x2_up = x1 + out_tile_size
            mask_pred[y1:y2_up, x1:x2_up] += preds_tensor[i]
            mask_count[y1:y2_up, x1:x2_up] += 1.0
        # Average overlapping predictions; pixels covered by no tile stay zero
        mask_pred = np.divide(mask_pred, mask_count, out=np.zeros_like(mask_pred), where=mask_count != 0)
        return mask_pred
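
    # Averaging illustration (comments only, not executed): with
    # tile_size = 64 and stride = 32, an interior pixel falls inside four
    # tiles, so its final value is the mean of four overlapping predictions;
    # border pixels are covered by fewer tiles.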

    def _sanitize_parameters(self, **kwargs):
        # Required by the Pipeline API; this pipeline takes no extra runtime kwargs
        return {}, {}, {}

    def __call__(self, image: np.ndarray):
        """
        Args:
            image: np.ndarray of shape (m, n, d) input volume.
        Returns:
            mask_pred: np.ndarray of shape (m, n) predicted ink-probability mask.
        """
        tiles, coords, full_shape = self.preprocess(image)
        # _forward already chunks the tiles into batches of self.batch_size
        outputs = self._forward(tiles)
        mask_pred = self.postprocess(outputs, coords, full_shape)
        return mask_pred
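

# Minimal usage sketch, assuming a compatible 3D-to-2D ink model hosted on the
# Hugging Face Hub; the repo id, the volume depth, and the random stand-in
# volume below are placeholders, not part of this pipeline.
if __name__ == "__main__":
    # Hypothetical repo id for a model whose custom code/weights match this pipeline
    model_id = "your-username/your-ink-model"
    model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
    pipe = InkDetectionPipeline(model, device='cuda', tile_size=64, stride=32,
                                scale_factor=16, batch_size=32)
    # Stand-in volume: (m, n, d) stack of surface layers; replace with real data
    volume = np.random.rand(256, 256, 26).astype(np.float32)
    mask = pipe(volume)  # (256, 256) ink-probability map in [0, 1]
    print(mask.shape, mask.min(), mask.max())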