YoussefMoNader committed
Commit 3c6fea2 · verified · 1 Parent(s): 6873be0

Upload Image3DPredictionPipeline

Files changed (2)
  1. config.json +9 -0
  2. ink_detection_pipeline.py +146 -0
config.json CHANGED
@@ -7,6 +7,15 @@
     "AutoConfig": "YoussefMoNader/timesformer-test4--timesformer_config.TimesformerScrollprizeConfig",
     "AutoModel": "YoussefMoNader/timesformer-test4--timesformer_model.TimesformerScrollprizeModel"
   },
+  "custom_pipelines": {
+    "ink-detection": {
+      "impl": "ink_detection_pipeline.Image3DPredictionPipeline",
+      "pt": [
+        "AutoModel"
+      ],
+      "tf": []
+    }
+  },
   "depth": 8,
   "dim": 512,
   "n_heads": 6,
ink_detection_pipeline.py ADDED
@@ -0,0 +1,146 @@
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from transformers import Pipeline
+ from tqdm import tqdm
+
+
+ class Image3DPredictionPipeline(Pipeline):
+     """
+     A custom pipeline that:
+     1. Takes a 3D image of shape (m, n, d).
+     2. Cuts it into (tile_size, tile_size, d) tiles with a given stride.
+     3. Runs model inference on each tile, feeding the d slices to the
+        model as its temporal/depth dimension.
+     4. Reconstructs the per-tile predictions into a full-size output.
+     """
+
+     def __init__(self, model, device='cuda', tile_size=64, stride=32, scale_factor=16, **kwargs):
+         super().__init__(model=model, tokenizer=None, device=0 if device == 'cuda' else -1)
+         self.model = model.to(device)
+         self.device = device
+         self.tile_size = tile_size
+         self.stride = stride
+         self.scale_factor = scale_factor
+         self.batch_size = 64
+
+     def preprocess(self, inputs):
+         """
+         inputs: np.ndarray of shape (m, n, d)
+         Cuts the input volume into tiles of shape (tile_size, tile_size, d)
+         with the configured stride.
+         Returns:
+             tiles: list of np arrays, each (tile_size, tile_size, d)
+             coords: list of (x1, y1, x2, y2) tile coordinates
+             (m, n): spatial shape of the input volume
+         """
+         volume = inputs
+         m, n, d = volume.shape
+         tiles = []
+         coords = []
+
+         # Extract overlapping patches. Borders beyond the last full stride
+         # are not covered; pad the volume beforehand if full coverage matters.
+         for y in range(0, m - self.tile_size + 1, self.stride):
+             for x in range(0, n - self.tile_size + 1, self.stride):
+                 y1, y2 = y, y + self.tile_size
+                 x1, x2 = x, x + self.tile_size
+                 patch = volume[y1:y2, x1:x2]  # (tile_size, tile_size, d)
+                 tiles.append(patch)
+                 coords.append((x1, y1, x2, y2))
+         return tiles, coords, (m, n)
+
+     def _forward(self, model_inputs):
+         """
+         model_inputs: a list of B patches, each (tile_size, tile_size, d).
+         The model consumes (B, C=1, d, H=tile_size, W=tile_size) and returns
+         a spatially downsampled logit map; its side length times scale_factor
+         must equal tile_size so that postprocess() can reassemble the tiles.
+         Tiles are processed in chunks of self.batch_size to bound memory use.
+         """
+         patches = model_inputs
+         B = len(patches)
+
+         # (B, tile_size, tile_size, d) -> (B, d, tile_size, tile_size)
+         batch = np.stack([p.transpose(2, 0, 1) for p in patches], axis=0).astype(np.float32)
+         batch = torch.from_numpy(batch).float()
+
+         all_preds = []
+         # Process in batches to save memory
+         for start_idx in tqdm(range(0, B, self.batch_size)):
+             end_idx = start_idx + self.batch_size
+             sub_batch = batch[start_idx:end_idx]  # (subB, d, tile_size, tile_size)
+
+             # Add channel dimension: (subB, 1, d, tile_size, tile_size)
+             sub_batch = sub_batch.unsqueeze(1)
+
+             with torch.no_grad(), torch.autocast(self.device if self.device == 'cuda' else 'cpu'):
+                 sub_y_preds = self.model(sub_batch.to(self.device))
+
+             # Map logits to probabilities
+             sub_y_preds = torch.sigmoid(sub_y_preds)
+
+             # Move to CPU / numpy: (subB, 1, h, w), h = w = tile_size // scale_factor
+             sub_y_preds = sub_y_preds.detach().cpu().numpy()
+             all_preds.append(sub_y_preds)
+
+         # Concatenate along the batch dimension
+         y_preds = np.concatenate(all_preds, axis=0)  # (B, 1, h, w)
+         return y_preds
+
+     def postprocess(self, model_outputs, coords, full_shape):
+         """
+         model_outputs: np.ndarray of shape (B, 1, h, w)
+         coords: list of (x1, y1, x2, y2) for each tile
+         full_shape: (m, n)
+
+         Upsamples each prediction by scale_factor back to tile_size, places
+         it into a full (m, n) array, accumulates overlapping predictions,
+         and divides by the per-pixel overlap count.
+         """
+         m, n = full_shape
+         # mask_pred accumulates predictions, mask_count the overlap counts
+         mask_pred = np.zeros((m, n), dtype=np.float32)
+         mask_count = np.zeros((m, n), dtype=np.float32)
+
+         # Upsample each prediction back to tile resolution with PyTorch
+         preds_tensor = torch.from_numpy(model_outputs.astype(np.float32))  # (B, 1, h, w)
+         if self.scale_factor != 1:
+             preds_tensor = F.interpolate(
+                 preds_tensor, scale_factor=self.scale_factor, mode='bilinear', align_corners=False
+             )
+         # shape after upsampling: (B, 1, tile_size, tile_size)
+         preds_tensor = preds_tensor.squeeze(1).numpy()  # (B, tile_size, tile_size)
+
+         # Upsampled tiles land back on the original tile_size grid
+         out_tile_size = self.tile_size
+
+         for i, (x1, y1, x2, y2) in enumerate(coords):
+             y2_up = y1 + out_tile_size
+             x2_up = x1 + out_tile_size
+
+             mask_pred[y1:y2_up, x1:x2_up] += preds_tensor[i]
+             mask_count[y1:y2_up, x1:x2_up] += 1.0
+
+         # Average overlapping predictions; pixels never covered stay 0
+         mask_pred = np.divide(mask_pred, mask_count, out=np.zeros_like(mask_pred), where=mask_count != 0)
+         return mask_pred
+
+     def _sanitize_parameters(self, **kwargs):
+         # No extra runtime parameters for preprocess/_forward/postprocess
+         return {}, {}, {}
+
+     def __call__(self, image: np.ndarray):
+         """
+         Args:
+             image: np.ndarray of shape (m, n, d) input volume.
+         Returns:
+             mask_pred: np.ndarray of shape (m, n) predicted ink mask.
+         """
+         tiles, coords, full_shape = self.preprocess(image)
+         outputs = self._forward(tiles)  # tiles are batched inside _forward
+         mask_pred = self.postprocess(outputs, coords, full_shape)
+         return mask_pred
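
With the default tile_size=64 and stride=32, interior pixels fall inside up to four overlapping tiles, and the divide-by-mask_count step in postprocess() turns the accumulated sums into a per-pixel average. A standalone sketch of just that coverage bookkeeping, with toy sizes and no model involved:

import numpy as np

tile, stride, m, n = 64, 32, 256, 256
count = np.zeros((m, n), dtype=np.float32)
for y in range(0, m - tile + 1, stride):
    for x in range(0, n - tile + 1, stride):
        count[y:y + tile, x:x + tile] += 1.0

# Corner pixels are covered by a single tile, interior pixels by 2 * 2 = 4.
print(count.min(), count.max())  # 1.0 4.0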