You-Py commited on
Commit
935ee9e
·
verified ·
1 Parent(s): 23f6916

Create Usage prediction.py

Browse files
Files changed (1) hide show
  1. Usage prediction.py +424 -0
Usage prediction.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ import torch.nn as nn
8
+ from torchvision import models
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from typing import Dict, List, Tuple, Optional, Union
13
+ from dataclasses import dataclass
14
+ import warnings
15
+ warnings.filterwarnings('ignore')
16
+
17
+ # ----------------------------
18
+ # Configuration
19
+ # ----------------------------
20
+ @dataclass
21
+ class InferenceConfig:
22
+ # Model Configuration
23
+ model_name: str = "resnet34"
24
+ embedding_dim: int = 128
25
+ normalize_embeddings: bool = True
26
+ checkpoint_path: str = "../../model/models_checkpoints/best_model.pth"
27
+
28
+ # Inference Settings
29
+ batch_size: int = 32
30
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
31
+ distance_threshold: float = 0.5 # Will be loaded from checkpoint
32
+
33
+ # Data Settings
34
+ remove_bg: bool = False
35
+ num_workers: int = 4
36
+
37
+ # Global configuration
38
+ CONFIG = InferenceConfig()
39
+
40
+ # ----------------------------
41
+ # Model Architecture (Same as training)
42
+ # ----------------------------
43
+ class ResNetBackbone(nn.Module):
44
+ """ResNet backbone feature extractor."""
45
+
46
+ def __init__(self, model_name: str = "resnet34"):
47
+ super().__init__()
48
+
49
+ if model_name == "resnet18":
50
+ self.resnet = models.resnet18(weights=None)
51
+ elif model_name == "resnet34":
52
+ self.resnet = models.resnet34(weights=None)
53
+ elif model_name == "resnet50":
54
+ self.resnet = models.resnet50(weights=None)
55
+ else:
56
+ raise ValueError(f"Unsupported model_name: {model_name}")
57
+
58
+ # Remove the fully connected layer
59
+ self.resnet.fc = nn.Identity()
60
+
61
+ # Get output dimension
62
+ with torch.no_grad():
63
+ dummy = torch.randn(1, 3, 224, 224)
64
+ self.output_dim = self.resnet(dummy).shape[1]
65
+
66
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
67
+ return self.resnet(x)
68
+
69
+ class AdvancedEmbeddingHead(nn.Module):
70
+ """Embedding head to project features to embedding space."""
71
+
72
+ def __init__(self, input_dim: int, embedding_dim: int, dropout: float = 0.5):
73
+ super().__init__()
74
+
75
+ self.input_dim = input_dim
76
+ self.embedding_dim = embedding_dim
77
+
78
+ if input_dim > embedding_dim * 4:
79
+ hidden_dim = max(embedding_dim * 2, input_dim // 4)
80
+ self.layers = nn.Sequential(
81
+ nn.Linear(input_dim, hidden_dim),
82
+ nn.LayerNorm(hidden_dim),
83
+ nn.GELU(),
84
+ nn.Dropout(dropout),
85
+
86
+ nn.Linear(hidden_dim, embedding_dim * 2),
87
+ nn.LayerNorm(embedding_dim * 2),
88
+ nn.GELU(),
89
+ nn.Dropout(dropout / 2),
90
+
91
+ nn.Linear(embedding_dim * 2, embedding_dim),
92
+ nn.LayerNorm(embedding_dim)
93
+ )
94
+ else:
95
+ self.layers = nn.Sequential(
96
+ nn.Linear(input_dim, embedding_dim),
97
+ nn.LayerNorm(embedding_dim)
98
+ )
99
+
100
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
101
+ x = x.flatten(1)
102
+ return self.layers(x)
103
+
104
+ class SiameseSignatureNetwork(nn.Module):
105
+ """Siamese network for signature verification."""
106
+
107
+ def __init__(self, config: InferenceConfig = CONFIG):
108
+ super().__init__()
109
+ self.config = config
110
+
111
+ # Initialize backbone
112
+ self.backbone = ResNetBackbone(model_name=config.model_name)
113
+ backbone_dim = self.backbone.output_dim
114
+
115
+ # Initialize embedding head
116
+ self.embedding_head = AdvancedEmbeddingHead(
117
+ input_dim=backbone_dim,
118
+ embedding_dim=config.embedding_dim,
119
+ dropout=0.0 # No dropout during inference
120
+ )
121
+
122
+ self.normalize_embeddings = config.normalize_embeddings
123
+ self.distance_threshold = config.distance_threshold
124
+
125
+ def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
126
+ """Forward pass for inference."""
127
+ # Extract features
128
+ f1 = self.backbone(img1)
129
+ f2 = self.backbone(img2)
130
+
131
+ # Get embeddings
132
+ emb1 = self.embedding_head(f1)
133
+ emb2 = self.embedding_head(f2)
134
+
135
+ # Normalize if configured
136
+ if self.normalize_embeddings:
137
+ emb1 = F.normalize(emb1, p=2, dim=1)
138
+ emb2 = F.normalize(emb2, p=2, dim=1)
139
+
140
+ return emb1, emb2
141
+
142
+ def predict_pair(self, img1: torch.Tensor, img2: torch.Tensor,
143
+ threshold: Optional[float] = None) -> Dict[str, torch.Tensor]:
144
+ """Predict similarity between image pairs."""
145
+ self.eval()
146
+ with torch.no_grad():
147
+ emb1, emb2 = self(img1, img2)
148
+ distances = F.pairwise_distance(emb1, emb2)
149
+
150
+ thresh = threshold if threshold is not None else self.distance_threshold
151
+ predictions = (distances < thresh).long()
152
+
153
+ # Convert distance to similarity score (0-1, higher is more similar)
154
+ similarities = 1.0 / (1.0 + distances)
155
+
156
+ return {
157
+ 'predictions': predictions,
158
+ 'distances': distances,
159
+ 'similarities': similarities,
160
+ 'threshold': torch.tensor(thresh)
161
+ }
162
+
163
+ # ----------------------------
164
+ # Dataset for Batch Prediction
165
+ # ----------------------------
166
+ class PredictionDataset(Dataset):
167
+ """Dataset for batch prediction from Excel."""
168
+
169
+ def __init__(self, excel_path: str, image_folder: str, config: InferenceConfig = CONFIG):
170
+ self.image_folder = image_folder
171
+ self.config = config
172
+ self.data = pd.read_excel(excel_path)
173
+ self.transform = self._get_transforms()
174
+
175
+ # Check required columns
176
+ required_cols = ['image_1_path', 'image_2_path']
177
+ missing_cols = [col for col in required_cols if col not in self.data.columns]
178
+ if missing_cols:
179
+ raise ValueError(f"Missing required columns: {missing_cols}")
180
+
181
+ def _get_transforms(self) -> transforms.Compose:
182
+ """Get image transforms for inference."""
183
+ return transforms.Compose([
184
+ transforms.Resize((224, 224)),
185
+ transforms.ToTensor(),
186
+ transforms.Normalize(
187
+ mean=[0.485, 0.456, 0.406],
188
+ std=[0.229, 0.224, 0.225]
189
+ )
190
+ ])
191
+
192
+ def __len__(self) -> int:
193
+ return len(self.data)
194
+
195
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
196
+ """Return image pair and index."""
197
+ row = self.data.iloc[idx]
198
+
199
+ img1 = self._load_image(row['image_1_path'])
200
+ img2 = self._load_image(row['image_2_path'])
201
+
202
+ return img1, img2, idx
203
+
204
+ def _load_image(self, image_path: str) -> torch.Tensor:
205
+ """Load and transform image."""
206
+ image = replace_background_with_white(
207
+ image_path, self.image_folder,
208
+ remove_bg=self.config.remove_bg
209
+ )
210
+ return self.transform(image)
211
+
212
+ # ----------------------------
213
+ # Image Processing
214
+ # ----------------------------
215
+ def estimate_background_color_pil(image: Image.Image, border_width: int = 10,
216
+ method: str = "median") -> np.ndarray:
217
+ """Estimate background color from image borders."""
218
+ if image.mode != 'RGB':
219
+ image = image.convert('RGB')
220
+
221
+ np_img = np.array(image)
222
+ h, w, _ = np_img.shape
223
+
224
+ # Extract border pixels
225
+ top = np_img[:border_width, :, :].reshape(-1, 3)
226
+ bottom = np_img[-border_width:, :, :].reshape(-1, 3)
227
+ left = np_img[:, :border_width, :].reshape(-1, 3)
228
+ right = np_img[:, -border_width:, :].reshape(-1, 3)
229
+
230
+ all_border_pixels = np.concatenate([top, bottom, left, right], axis=0)
231
+
232
+ if method == "mean":
233
+ return np.mean(all_border_pixels, axis=0).astype(np.uint8)
234
+ else:
235
+ return np.median(all_border_pixels, axis=0).astype(np.uint8)
236
+
237
+ def replace_background_with_white(image_name: str, folder_img: str,
238
+ tolerance: int = 40, method: str = "median",
239
+ remove_bg: bool = False) -> Image.Image:
240
+ """Replace background with white based on border color estimation."""
241
+ image_path = os.path.join(folder_img, image_name)
242
+ image = Image.open(image_path).convert("RGB")
243
+
244
+ if not remove_bg:
245
+ return image
246
+
247
+ np_img = np.array(image)
248
+ bg_color = estimate_background_color_pil(image, method=method)
249
+
250
+ # Create mask for background pixels
251
+ diff = np.abs(np_img.astype(np.int32) - bg_color.astype(np.int32))
252
+ mask = np.all(diff < tolerance, axis=2)
253
+
254
+ # Replace background with white
255
+ result = np_img.copy()
256
+ result[mask] = [255, 255, 255]
257
+
258
+ return Image.fromarray(result)
259
+
260
+ # ----------------------------
261
+ # Main Prediction Class
262
+ # ----------------------------
263
+ class SignatureVerifier:
264
+ """Main class for signature verification predictions."""
265
+
266
+ def __init__(self, config: InferenceConfig = CONFIG):
267
+ self.config = config
268
+ self.device = torch.device(config.device)
269
+ self.model = self._load_model()
270
+ self.transform = self._get_transforms()
271
+
272
+ def _get_transforms(self) -> transforms.Compose:
273
+ """Get image transforms."""
274
+ return transforms.Compose([
275
+ transforms.Resize((224, 224)),
276
+ transforms.ToTensor(),
277
+ transforms.Normalize(
278
+ mean=[0.485, 0.456, 0.406],
279
+ std=[0.229, 0.224, 0.225]
280
+ )
281
+ ])
282
+
283
+ def _load_model(self) -> SiameseSignatureNetwork:
284
+ """Load model from checkpoint."""
285
+ print(f"Loading model from: {self.config.checkpoint_path}")
286
+
287
+ # Initialize model
288
+ model = SiameseSignatureNetwork(self.config)
289
+
290
+ # Load checkpoint
291
+ checkpoint = torch.load(self.config.checkpoint_path, map_location=self.device, weights_only=False)
292
+
293
+ # Load model state
294
+ if 'model_state_dict' in checkpoint:
295
+ model.load_state_dict(checkpoint['model_state_dict'])
296
+ else:
297
+ # If checkpoint is just the state dict
298
+ model.load_state_dict(checkpoint)
299
+
300
+ # Load threshold if available
301
+ if 'prediction_threshold' in checkpoint:
302
+ model.distance_threshold = checkpoint['prediction_threshold']
303
+ print(f"Loaded threshold: {model.distance_threshold:.4f}")
304
+
305
+ # Load best EER if available
306
+ if 'best_eer' in checkpoint:
307
+ print(f"Model best EER: {checkpoint['best_eer']:.4f}")
308
+
309
+ model = model.to(self.device)
310
+ model.eval()
311
+
312
+ print("Model loaded successfully!")
313
+ return model
314
+
315
+ def predict_single_pair(self, image1_path: str, image2_path: str,
316
+ image_folder: str = "") -> Dict[str, float]:
317
+ """Predict similarity for a single pair of images."""
318
+ # Load images
319
+ img1 = replace_background_with_white(
320
+ image1_path, image_folder, remove_bg=self.config.remove_bg
321
+ )
322
+ img2 = replace_background_with_white(
323
+ image2_path, image_folder, remove_bg=self.config.remove_bg
324
+ )
325
+
326
+ # Transform
327
+ img1_tensor = self.transform(img1).unsqueeze(0).to(self.device)
328
+ img2_tensor = self.transform(img2).unsqueeze(0).to(self.device)
329
+
330
+ # Predict
331
+ results = self.model.predict_pair(img1_tensor, img2_tensor)
332
+
333
+ return {
334
+ 'is_genuine': bool(results['predictions'].item()),
335
+ 'distance': float(results['distances'].item()),
336
+ 'similarity_score': float(results['similarities'].item()),
337
+ 'threshold': float(results['threshold'].item())
338
+ }
339
+
340
+ def predict_from_excel(self, excel_path: str, image_folder: str,
341
+ output_path: Optional[str] = None) -> pd.DataFrame:
342
+ """Batch prediction from Excel file."""
343
+ # Create dataset and dataloader
344
+ dataset = PredictionDataset(excel_path, image_folder, self.config)
345
+ dataloader = DataLoader(
346
+ dataset,
347
+ batch_size=self.config.batch_size,
348
+ shuffle=False,
349
+ num_workers=self.config.num_workers,
350
+ pin_memory=True
351
+ )
352
+
353
+ # Prediction storage
354
+ all_predictions = []
355
+ all_distances = []
356
+ all_similarities = []
357
+
358
+ # Predict in batches
359
+ print(f"Processing {len(dataset)} pairs...")
360
+ with torch.no_grad():
361
+ for img1_batch, img2_batch, indices in tqdm(dataloader):
362
+ img1_batch = img1_batch.to(self.device)
363
+ img2_batch = img2_batch.to(self.device)
364
+
365
+ results = self.model.predict_pair(img1_batch, img2_batch)
366
+
367
+ all_predictions.extend(results['predictions'].cpu().numpy())
368
+ all_distances.extend(results['distances'].cpu().numpy())
369
+ all_similarities.extend(results['similarities'].cpu().numpy())
370
+
371
+ # Create results dataframe
372
+ results_df = dataset.data.copy()
373
+ results_df['prediction'] = all_predictions
374
+ results_df['is_genuine'] = results_df['prediction'].astype(bool)
375
+ results_df['distance'] = all_distances
376
+ results_df['similarity_score'] = all_similarities
377
+ results_df['threshold'] = self.model.distance_threshold
378
+
379
+ # Save if output path provided
380
+ if output_path:
381
+ results_df.to_excel(output_path, index=False)
382
+ print(f"Results saved to: {output_path}")
383
+
384
+ return results_df
385
+
386
+ def update_threshold(self, new_threshold: float):
387
+ """Update the decision threshold."""
388
+ self.model.distance_threshold = new_threshold
389
+ print(f"Threshold updated to: {new_threshold:.4f}")
390
+
391
+ # Initialize verifier
392
+ config = InferenceConfig(
393
+ checkpoint_path="../../../../model/models_checkpoints/fa7e1bdc01814016ac8220bfbf1eb691/best_model.pth",
394
+ batch_size=32,
395
+ device="cuda" if torch.cuda.is_available() else "cpu"
396
+ )
397
+
398
+ verifier = SignatureVerifier(config)
399
+
400
+ '''
401
+ # Example 1: Single pair prediction
402
+ print("\n--- Single Pair Prediction ---")
403
+ result = verifier.predict_single_pair(
404
+ image1_path="sig1.png",
405
+ image2_path="sig2.png",
406
+ image_folder="../../data/classify/preprared_data/images/"
407
+ )
408
+ '''
409
+
410
+ # Example 2: Batch prediction from Excel
411
+ print("\n--- Batch Prediction from Excel ---")
412
+ results_df = verifier.predict_from_excel(
413
+ excel_path="../../../../data/classify/preprared_data/labels/test_pairs_balanced_v12.xlsx",
414
+ image_folder="../../../../data/classify/preprared_data/images/",
415
+ output_path="./predictions_output.xlsx"
416
+ )
417
+
418
+ # Print summary
419
+ genuine_count = results_df['is_genuine'].sum()
420
+ total_count = len(results_df)
421
+ print(f"\nPrediction Summary:")
422
+ print(f"Total pairs: {total_count}")
423
+ print(f"Genuine predictions: {genuine_count} ({100*genuine_count/total_count:.1f}%)")
424
+ print(f"Forged predictions: {total_count - genuine_count} ({100*(total_count-genuine_count)/total_count:.1f}%)")