File size: 5,877 Bytes
3c6fea2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import torch
import torch.nn.functional as F
import numpy as np
from transformers import Pipeline,AutoModel
from tqdm import tqdm
class Image3DPredictionPipeline(Pipeline):
    """Sliding-window inference pipeline for a 3-D volume.

    1. Takes in a 3D image of shape (m, n, d).
    2. Cuts it into overlapping (tile_size, tile_size, d) tiles with the given
       stride. The last row/column of tiles is snapped to the image border so
       every pixel is covered by at least one tile.
    3. Runs model inference on the tiles in mini-batches of ``batch_size``.
    4. Upsamples each tile prediction by ``scale_factor`` and averages the
       overlapping tiles back into a full (m, n) probability map.
    """

    def __init__(self, model, device='cuda', tile_size=64, stride=32,
                 scale_factor=16, batch_size=64, **kwargs):
        """
        Args:
            model: torch module called on each mini-batch of tiles; it is
                moved to ``device``.
            device: torch device string. Exactly ``'cuda'`` maps to HF device
                index 0 for the base class; anything else maps to -1.
            tile_size: side length (pixels) of the square tiles.
            stride: step between tile origins; ``stride < tile_size`` gives
                overlapping tiles.
            scale_factor: factor by which each raw model output is bilinearly
                upsampled in ``postprocess``. The upsampled prediction must be
                (tile_size, tile_size).
            batch_size: number of tiles per forward pass.
            **kwargs: accepted for compatibility; currently ignored.
        """
        super().__init__(model=model, tokenizer=None, device=0 if device == 'cuda' else -1)
        self.model = model.to(device)
        # Replaces whatever Pipeline.__init__ stored with the raw string that
        # the methods below pass to ``.to()`` / inspect for autocast.
        self.device = device
        self.tile_size = tile_size
        self.stride = stride
        self.scale_factor = scale_factor
        self.batch_size = batch_size

    @staticmethod
    def _tile_starts(length, tile_size, stride):
        """Return tile start offsets along one axis, always ending at ``length - tile_size``."""
        last = length - tile_size
        starts = list(range(0, last + 1, stride))
        # Without this the trailing border is never covered when
        # (length - tile_size) is not a multiple of stride.
        if starts[-1] != last:
            starts.append(last)
        return starts

    def preprocess(self, inputs):
        """Cut the input volume into (tile_size, tile_size, d) tiles.

        Args:
            inputs: np.ndarray of shape (m, n, d).

        Returns:
            tiles: list of views into ``inputs``, each (tile_size, tile_size, d).
            coords: list of (x1, y1, x2, y2); x indexes columns (axis 1) and
                y indexes rows (axis 0).
            full_shape: (m, n).

        Raises:
            ValueError: if ``inputs`` is not 3-D or m or n is smaller than
                ``tile_size``.
        """
        volume = inputs
        m, n, d = volume.shape
        if m < self.tile_size or n < self.tile_size:
            raise ValueError(
                f"Input spatial size ({m}, {n}) is smaller than tile_size={self.tile_size}."
            )
        x_starts = self._tile_starts(n, self.tile_size, self.stride)
        tiles = []
        coords = []
        for y1 in self._tile_starts(m, self.tile_size, self.stride):
            y2 = y1 + self.tile_size
            for x1 in x_starts:
                x2 = x1 + self.tile_size
                tiles.append(volume[y1:y2, x1:x2])
                coords.append((x1, y1, x2, y2))
        return tiles, coords, (m, n)

    def _forward(self, model_inputs):
        """Run the model over all tiles in mini-batches and return sigmoid probabilities.

        Args:
            model_inputs: non-empty list of (tile_size, tile_size, d) arrays.

        Returns:
            np.ndarray (float32) of the stacked per-tile sigmoid outputs,
            concatenated along axis 0 (one entry per tile).

        Each mini-batch is fed to the model as a tensor of shape
        (subB, 1, d, tile_size, tile_size).
        NOTE(review): ``postprocess`` expects the model to return
        (subB, 1, h, w) with h * scale_factor == tile_size — confirm.
        """
        patches = model_inputs
        # (B, tile, tile, d) -> (B, d, tile, tile), float32 for the model.
        batch = np.stack([p.transpose(2, 0, 1) for p in patches], axis=0, dtype=np.float32)
        batch = torch.from_numpy(batch)
        # autocast wants a device *type*; accept indexed strings like 'cuda:1'.
        device_type = 'cuda' if str(self.device).startswith('cuda') else 'cpu'
        all_preds = []
        # Step by batch_size (not a hard-coded 64) so slices neither overlap nor skip.
        for start in tqdm(range(0, len(batch), self.batch_size)):
            sub_batch = batch[start:start + self.batch_size]
            # Insert a singleton channel axis: (subB, d, t, t) -> (subB, 1, d, t, t).
            sub_batch = sub_batch.unsqueeze(1).to(self.device)
            with torch.no_grad(), torch.autocast(device_type):
                probs = torch.sigmoid(self.model(sub_batch))
            # Autocast may produce float16/bfloat16; numpy has no bfloat16,
            # so upcast before converting.
            all_preds.append(probs.float().cpu().numpy())
        return np.concatenate(all_preds, axis=0)

    def postprocess(self, model_outputs, coords, full_shape):
        """Blend per-tile predictions into a full-size map by averaging overlaps.

        Args:
            model_outputs: np.ndarray of shape (B, 1, h, w), one per tile.
            coords: list of (x1, y1, x2, y2) per tile, as produced by ``preprocess``.
            full_shape: (m, n) of the original volume.

        Returns:
            np.ndarray (float32) of shape (m, n). Pixels covered by no tile are 0.

        Raises:
            ValueError: if the (upsampled) predictions are not (tile_size, tile_size).
        """
        m, n = full_shape
        mask_pred = np.zeros((m, n), dtype=np.float32)
        mask_count = np.zeros((m, n), dtype=np.float32)
        preds = torch.from_numpy(model_outputs.astype(np.float32))
        if self.scale_factor != 1:
            preds = F.interpolate(
                preds, scale_factor=self.scale_factor, mode='bilinear', align_corners=False
            )
        preds = preds.squeeze(1).numpy()  # (B, H_out, W_out)
        expected = (self.tile_size, self.tile_size)
        if preds.shape[1:] != expected:
            raise ValueError(
                f"Upsampled tile predictions have shape {preds.shape[1:]}, expected {expected}; "
                f"check scale_factor={self.scale_factor} against the model output size."
            )
        for pred, (x1, y1, x2, y2) in zip(preds, coords):
            mask_pred[y1:y2, x1:x2] += pred
            mask_count[y1:y2, x1:x2] += 1.0
        return np.divide(mask_pred, mask_count, out=np.zeros_like(mask_pred), where=mask_count != 0)

    def _sanitize_parameters(self, **kwargs):
        """Part of the transformers Pipeline interface; no call-time parameters are forwarded."""
        return {}, {}, {}

    def __call__(self, image: np.ndarray):
        """Run tiling, batched inference and blending on one volume.

        Args:
            image: np.ndarray of shape (m, n, d) input volume.

        Returns:
            mask_pred: np.ndarray of shape (m, n) with averaged sigmoid probabilities.
        """
        tiles, coords, full_shape = self.preprocess(image)
        outputs = self._forward(tiles)
        return self.postprocess(outputs, coords, full_shape)
|