import os
import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm import tqdm

warnings.filterwarnings('ignore')


@dataclass
class InferenceConfig:
    """Configuration for signature verification inference."""

    # Model architecture
    model_name: str = "resnet34"
    embedding_dim: int = 128
    normalize_embeddings: bool = True
    checkpoint_path: str = "../../model/models_checkpoints/best_model.pth"

    # Inference
    batch_size: int = 32
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    distance_threshold: float = 0.5

    # Preprocessing
    remove_bg: bool = False
    num_workers: int = 4


CONFIG = InferenceConfig()


class ResNetBackbone(nn.Module):
    """ResNet backbone feature extractor."""

    def __init__(self, model_name: str = "resnet34"):
        super().__init__()

        if model_name == "resnet18":
            self.resnet = models.resnet18(weights=None)
        elif model_name == "resnet34":
            self.resnet = models.resnet34(weights=None)
        elif model_name == "resnet50":
            self.resnet = models.resnet50(weights=None)
        else:
            raise ValueError(f"Unsupported model_name: {model_name}")

        # Drop the classification head: the backbone now ends at the pooled features.
        self.resnet.fc = nn.Identity()

        # Probe with a dummy input to discover the backbone's feature dimension
        # (512 for resnet18/34, 2048 for resnet50).
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            self.output_dim = self.resnet(dummy).shape[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.resnet(x)


class AdvancedEmbeddingHead(nn.Module):
    """Embedding head that projects backbone features into the embedding space."""

    def __init__(self, input_dim: int, embedding_dim: int, dropout: float = 0.5):
        super().__init__()

        self.input_dim = input_dim
        self.embedding_dim = embedding_dim

        # Use the deep projection only when the backbone features are much wider
        # than the target embedding. With embedding_dim=128 this means resnet50
        # (2048 > 512) takes the deep path, while resnet18/34 (512) fail the
        # strict inequality and take the single-layer path below.
        if input_dim > embedding_dim * 4:
            hidden_dim = max(embedding_dim * 2, input_dim // 4)
            self.layers = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),

                nn.Linear(hidden_dim, embedding_dim * 2),
                nn.LayerNorm(embedding_dim * 2),
                nn.GELU(),
                nn.Dropout(dropout / 2),

                nn.Linear(embedding_dim * 2, embedding_dim),
                nn.LayerNorm(embedding_dim)
            )
        else:
            self.layers = nn.Sequential(
                nn.Linear(input_dim, embedding_dim),
                nn.LayerNorm(embedding_dim)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.flatten(1)
        return self.layers(x)


class SiameseSignatureNetwork(nn.Module):
    """Siamese network for signature verification."""

    def __init__(self, config: InferenceConfig = CONFIG):
        super().__init__()
        self.config = config

        self.backbone = ResNetBackbone(model_name=config.model_name)
        backbone_dim = self.backbone.output_dim

        # Dropout is disabled because this network is built for inference only.
        self.embedding_head = AdvancedEmbeddingHead(
            input_dim=backbone_dim,
            embedding_dim=config.embedding_dim,
            dropout=0.0
        )

        self.normalize_embeddings = config.normalize_embeddings
        self.distance_threshold = config.distance_threshold

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed both images with the shared backbone and head."""
        f1 = self.backbone(img1)
        f2 = self.backbone(img2)

        emb1 = self.embedding_head(f1)
        emb2 = self.embedding_head(f2)

        # L2-normalize so pairwise distances live on the unit hypersphere.
        if self.normalize_embeddings:
            emb1 = F.normalize(emb1, p=2, dim=1)
            emb2 = F.normalize(emb2, p=2, dim=1)

        return emb1, emb2

    def predict_pair(self, img1: torch.Tensor, img2: torch.Tensor,
                     threshold: Optional[float] = None) -> Dict[str, torch.Tensor]:
        """Predict similarity for a batch of image pairs."""
        self.eval()
        with torch.no_grad():
            emb1, emb2 = self(img1, img2)
            distances = F.pairwise_distance(emb1, emb2)

        thresh = threshold if threshold is not None else self.distance_threshold
        # A pair is predicted genuine (1) when its embeddings are close enough.
        predictions = (distances < thresh).long()

        # Map distances to a score in (0, 1]: identical embeddings score 1.0.
        similarities = 1.0 / (1.0 + distances)

        return {
            'predictions': predictions,
            'distances': distances,
            'similarities': similarities,
            'threshold': torch.tensor(thresh)
        }


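# Hedged sanity check (disabled, in the same style as the examples further
# down): random tensors stand in for real preprocessed signatures, so only
# shapes and types are meaningful here.
'''
net = SiameseSignatureNetwork(CONFIG)
dummy1 = torch.randn(4, 3, 224, 224)
dummy2 = torch.randn(4, 3, 224, 224)
out = net.predict_pair(dummy1, dummy2)
print(out['predictions'].shape)  # torch.Size([4]); 1 = predicted genuine
print(out['distances'].shape)    # torch.Size([4]); in [0, 2] for normalized embeddings
'''

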
class PredictionDataset(Dataset):
    """Dataset for batch prediction from an Excel file of image-pair paths."""

    def __init__(self, excel_path: str, image_folder: str, config: InferenceConfig = CONFIG):
        self.image_folder = image_folder
        self.config = config
        self.data = pd.read_excel(excel_path)
        self.transform = self._get_transforms()

        # Fail fast if the spreadsheet is missing the expected path columns.
        required_cols = ['image_1_path', 'image_2_path']
        missing_cols = [col for col in required_cols if col not in self.data.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")

    def _get_transforms(self) -> transforms.Compose:
        """Get image transforms for inference."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Return an image pair and its row index."""
        row = self.data.iloc[idx]

        img1 = self._load_image(row['image_1_path'])
        img2 = self._load_image(row['image_2_path'])

        return img1, img2, idx

    def _load_image(self, image_path: str) -> torch.Tensor:
        """Load, optionally background-clean, and transform an image."""
        image = replace_background_with_white(
            image_path, self.image_folder,
            remove_bg=self.config.remove_bg
        )
        return self.transform(image)


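# Hedged sketch: inspect a single dataset item before launching the batch job.
# "pairs.xlsx" and "./images" are placeholder paths, not files in this project.
'''
ds = PredictionDataset("pairs.xlsx", "./images", CONFIG)
img1, img2, idx = ds[0]
print(img1.shape, img2.shape, idx)  # torch.Size([3, 224, 224]) twice, row index 0
'''

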
def estimate_background_color_pil(image: Image.Image, border_width: int = 10,
                                  method: str = "median") -> np.ndarray:
    """Estimate the background color from the image borders."""
    if image.mode != 'RGB':
        image = image.convert('RGB')

    np_img = np.array(image)

    # Collect pixels from a border_width-wide frame around the image.
    top = np_img[:border_width, :, :].reshape(-1, 3)
    bottom = np_img[-border_width:, :, :].reshape(-1, 3)
    left = np_img[:, :border_width, :].reshape(-1, 3)
    right = np_img[:, -border_width:, :].reshape(-1, 3)

    all_border_pixels = np.concatenate([top, bottom, left, right], axis=0)

    # The median (default) is robust to stray ink strokes touching the border.
    if method == "mean":
        return np.mean(all_border_pixels, axis=0).astype(np.uint8)
    else:
        return np.median(all_border_pixels, axis=0).astype(np.uint8)


def replace_background_with_white(image_name: str, folder_img: str,
                                  tolerance: int = 40, method: str = "median",
                                  remove_bg: bool = False) -> Image.Image:
    """Replace the background with white based on border color estimation."""
    image_path = os.path.join(folder_img, image_name)
    image = Image.open(image_path).convert("RGB")

    if not remove_bg:
        return image

    np_img = np.array(image)
    bg_color = estimate_background_color_pil(image, method=method)

    # Mark pixels whose color is within `tolerance` of the estimated
    # background color on every channel.
    diff = np.abs(np_img.astype(np.int32) - bg_color.astype(np.int32))
    mask = np.all(diff < tolerance, axis=2)

    # Paint background pixels pure white; signature strokes differ strongly
    # enough from the background color to survive.
    result = np_img.copy()
    result[mask] = [255, 255, 255]

    return Image.fromarray(result)


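# Hedged standalone demo of the background-cleaning step; "example_sig.png" is
# a placeholder filename, not a file shipped with this project.
'''
raw = replace_background_with_white("example_sig.png", "./images", remove_bg=False)
clean = replace_background_with_white("example_sig.png", "./images", remove_bg=True)
# With remove_bg=True, every pixel within `tolerance` of the estimated border
# color is forced to white, while the darker strokes are left untouched.
clean.save("example_sig_white_bg.png")
'''

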
class SignatureVerifier:
    """Main class for signature verification predictions."""

    def __init__(self, config: InferenceConfig = CONFIG):
        self.config = config
        self.device = torch.device(config.device)
        self.model = self._load_model()
        self.transform = self._get_transforms()

    def _get_transforms(self) -> transforms.Compose:
        """Get image transforms (must match the training-time preprocessing)."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def _load_model(self) -> SiameseSignatureNetwork:
        """Load model from checkpoint."""
        print(f"Loading model from: {self.config.checkpoint_path}")

        model = SiameseSignatureNetwork(self.config)

        checkpoint = torch.load(self.config.checkpoint_path, map_location=self.device, weights_only=False)

        # Support both full training checkpoints and bare state dicts.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        # Prefer the threshold calibrated during training, if it was saved.
        if 'prediction_threshold' in checkpoint:
            model.distance_threshold = float(checkpoint['prediction_threshold'])
            print(f"Loaded threshold: {model.distance_threshold:.4f}")

        if 'best_eer' in checkpoint:
            print(f"Model best EER: {checkpoint['best_eer']:.4f}")

        model = model.to(self.device)
        model.eval()

        print("Model loaded successfully!")
        return model

    def predict_single_pair(self, image1_path: str, image2_path: str,
                            image_folder: str = "") -> Dict[str, float]:
        """Predict similarity for a single pair of images."""
        img1 = replace_background_with_white(
            image1_path, image_folder, remove_bg=self.config.remove_bg
        )
        img2 = replace_background_with_white(
            image2_path, image_folder, remove_bg=self.config.remove_bg
        )

        # Add a batch dimension of 1 and move to the model's device.
        img1_tensor = self.transform(img1).unsqueeze(0).to(self.device)
        img2_tensor = self.transform(img2).unsqueeze(0).to(self.device)

        results = self.model.predict_pair(img1_tensor, img2_tensor)

        return {
            'is_genuine': bool(results['predictions'].item()),
            'distance': float(results['distances'].item()),
            'similarity_score': float(results['similarities'].item()),
            'threshold': float(results['threshold'].item())
        }

    def predict_from_excel(self, excel_path: str, image_folder: str,
                           output_path: Optional[str] = None) -> pd.DataFrame:
        """Batch prediction from an Excel file of image pairs."""
        dataset = PredictionDataset(excel_path, image_folder, self.config)
        dataloader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

        all_predictions = []
        all_distances = []
        all_similarities = []

        print(f"Processing {len(dataset)} pairs...")
        with torch.no_grad():
            for img1_batch, img2_batch, indices in tqdm(dataloader):
                img1_batch = img1_batch.to(self.device)
                img2_batch = img2_batch.to(self.device)

                results = self.model.predict_pair(img1_batch, img2_batch)

                all_predictions.extend(results['predictions'].cpu().numpy())
                all_distances.extend(results['distances'].cpu().numpy())
                all_similarities.extend(results['similarities'].cpu().numpy())

        # shuffle=False guarantees the results line up with the original rows.
        results_df = dataset.data.copy()
        results_df['prediction'] = all_predictions
        results_df['is_genuine'] = results_df['prediction'].astype(bool)
        results_df['distance'] = all_distances
        results_df['similarity_score'] = all_similarities
        results_df['threshold'] = self.model.distance_threshold

        if output_path:
            results_df.to_excel(output_path, index=False)
            print(f"Results saved to: {output_path}")

        return results_df

    def update_threshold(self, new_threshold: float):
        """Update the decision threshold."""
        self.model.distance_threshold = new_threshold
        print(f"Threshold updated to: {new_threshold:.4f}")


config = InferenceConfig(
    checkpoint_path="../../../../model/models_checkpoints/fa7e1bdc01814016ac8220bfbf1eb691/best_model.pth",
    batch_size=32,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

verifier = SignatureVerifier(config)

'''
# Example 1: Single pair prediction
print("\n--- Single Pair Prediction ---")
result = verifier.predict_single_pair(
    image1_path="sig1.png",
    image2_path="sig2.png",
    image_folder="../../data/classify/preprared_data/images/"
)
'''

print("\n--- Batch Prediction from Excel ---")
results_df = verifier.predict_from_excel(
    excel_path="../../../../data/classify/preprared_data/labels/test_pairs_balanced_v12.xlsx",
    image_folder="../../../../data/classify/preprared_data/images/",
    output_path="./predictions_output.xlsx"
)

genuine_count = results_df['is_genuine'].sum()
total_count = len(results_df)
print("\nPrediction Summary:")
print(f"Total pairs: {total_count}")
print(f"Genuine predictions: {genuine_count} ({100*genuine_count/total_count:.1f}%)")
print(f"Forged predictions: {total_count - genuine_count} ({100*(total_count-genuine_count)/total_count:.1f}%)")