|
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score
from torch.utils.data import DataLoader
from tqdm import tqdm

from resnet_model_mask import ResidualBlock, ResNet
from utils import TestingDataset, transform
|
|
|
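# Fix the random seed for reproducibility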
torch.manual_seed(1) |
|
|
|
|
|
|
|
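# Run on the first GPU when available, and report how many GPUs are visible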
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
num_gpus = torch.cuda.device_count() |
|
print("num GPUs:", num_gpus)
|
|
|
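# Held-out test set; shuffle order does not affect the aggregated metrics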
test_data_dir = '/mnt/buf1/pma/frbnn/test_ready' |
|
test_dataset = TestingDataset(test_data_dir, transform=transform) |
|
|
|
num_classes = 2 |
|
testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32) |
|
|
|
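# Build the ResNet classifier and wrap it in DataParallel so inference can span all visible GPUs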
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device) |
|
model = nn.DataParallel(model) |
|
model = model.to(device) |
|
params = sum(p.numel() for p in model.parameters()) |
|
print("num params:", params)
|
|
|
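# Checkpoint to evaluate; the filename appears to encode epoch, validation accuracy, and seed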
model_1 = 'models_mask/model-43-99.235_42.pt' |
|
|
|
model.load_state_dict(torch.load(model_1, weights_only=True)) |
|
model.eval()
|
|
|
|
|
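# Accumulators for model outputs and the per-sample metadata carried in the labels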
correct_valid = 0
total = 0
results = {'output': [], 'pred': [], 'true': [], 'freq': [], 'snr': [], 'dm': [], 'boxcar': []}
|
with torch.no_grad():
    for images, labels in tqdm(testloader):
        inputs = images.to(device)
        outputs = model(inputs, return_mask=True)
        _, predicted = torch.max(outputs, 1)

        # The label tuple carries (true class, DM, frequency, SNR, boxcar width)
        results['output'].extend(outputs.cpu().numpy().tolist())
        results['pred'].extend(predicted.cpu().numpy().tolist())
        results['true'].extend(labels[0].cpu().numpy().tolist())
        results['dm'].extend(labels[1].cpu().numpy().tolist())
        results['freq'].extend(labels[2].cpu().numpy().tolist())
        results['snr'].extend(labels[3].cpu().numpy().tolist())
        results['boxcar'].extend(labels[4].cpu().numpy().tolist())

        total += labels[0].size(0)
        correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()
|
|
|
|
|
val_accuracy = correct_valid / total * 100.0
print("===========================")
print(f"accuracy: {val_accuracy:.2f}%")
print("===========================")
|
|
|
|
|
|
|
|
# Persist raw outputs and metadata for later analysis
with open('models_mask/test_42.pkl', 'wb') as f:
    pickle.dump(results, f)
|
|
|
|
|
|
|
|
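# Standard binary-classification metrics (sklearn defaults: positive class = 1)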
true = results['true'] |
|
pred = results['pred'] |
|
|
|
|
|
precision = precision_score(true, pred) |
|
recall = recall_score(true, pred) |
|
f1 = f1_score(true, pred) |
|
|
|
tn, fp, fn, tp = confusion_matrix(true, pred).ravel() |
|
|
|
|
|
fpr = fp / (fp + tn) |
|
|
|
print(f"False Positive Rate: {fpr:.3f}") |
|
|
|
print(f"Precision: {precision:.3f}") |
|
print(f"Recall: {recall:.3f}") |
|
print(f"F1 Score: {f1:.3f}") |
|
|
|
|
|
|
|
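# Assemble per-sample results; note the boxcar values are halved (presumably a unit conversion)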
df = pd.DataFrame({ |
|
'dm': results['dm'], |
|
'true': results['true'], |
|
'pred': results['pred'], |
|
'snr': results['snr'], |
|
'freq': results['freq'], |
|
'boxcar': np.array(results['boxcar'])/2 |
|
}) |
|
|
|
|
|
df = df[df['true'] == 1].copy() |
|
|
|
print(f"Filtered to {len(df)} samples with true label = 1") |
|
|
|
|
|
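# Bin the positive (true == 1) samples by DM and compute accuracy per bin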
dm_bins = np.linspace(df['dm'].min(), df['dm'].max(), 20) |
|
df['dm_bin'] = pd.cut(df['dm'], bins=dm_bins, include_lowest=True) |
|
print('min boxcar:', df['boxcar'].min())
|
|
|
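# Error bars below are binomial standard errors, se = sqrt(p * (1 - p) / n),
# scaled to percent so they match the accuracy axis.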
def calculate_accuracy_with_uncertainty(group):
    """Per-bin accuracy (in %) with its binomial standard error."""
    total = len(group)
    if total == 0:
        # Guard against empty bins, which pd.cut can produce
        return pd.Series({'accuracy': np.nan, 'std_error': np.nan, 'n_samples': 0})
    correct = (group['true'] == group['pred']).sum()
    p = correct / total
    accuracy = p * 100
    se = np.sqrt(p * (1 - p) / total) * 100
    return pd.Series({'accuracy': accuracy, 'std_error': se, 'n_samples': total})
|
|
|
dm_accuracy = df.groupby('dm_bin').apply(calculate_accuracy_with_uncertainty).reset_index() |
|
|
|
|
|
dm_accuracy['dm_midpoint'] = dm_accuracy['dm_bin'].apply(lambda x: x.mid) |
|
|
|
|
|
plt.figure(figsize=(10, 6)) |
|
ax1 = plt.gca() |
|
ax1.errorbar(dm_accuracy['dm_midpoint'], dm_accuracy['accuracy'], |
|
yerr=dm_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6, |
|
capsize=5, capthick=2, elinewidth=1) |
|
ax1.set_xlabel('Dispersion Measure (DM) [pc cm$^{-3}$]', fontsize=16) |
|
ax1.set_ylabel('Accuracy (%)', fontsize=16) |
|
ax1.set_title('Accuracy vs Dispersion Measure', fontsize=18) |
|
ax1.grid(True, alpha=0.3) |
|
ax1.set_ylim(97, 100) |
|
ax1.tick_params(axis='both', which='major', labelsize=14) |
|
|
|
|
|
yticks = ax1.get_yticks() |
|
ax1.set_yticks([tick for tick in yticks if tick <= 100]) |
|
|
|
|
|
ax1.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7, |
|
label=f'Overall: {val_accuracy:.2f}%') |
|
ax1.legend(fontsize=14) |
|
|
|
|
|
ax1.text(-0.1, -0.15, '(a)', transform=ax1.transAxes, fontsize=18, fontweight='bold') |
|
|
|
plt.tight_layout() |
|
plt.savefig('models_mask/accuracy_vs_dm.pdf', dpi=300, bbox_inches='tight') |
|
plt.show() |
|
|
|
|
|
|
|
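# Repeat the binned-accuracy analysis against SNR, keeping only samples with positive SNR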
df_snr_filtered = df[df['snr'] > 0].copy() |
|
|
|
|
|
snr_bins = np.linspace(df_snr_filtered['snr'].min(), df_snr_filtered['snr'].max(), 20) |
|
df_snr_filtered['snr_bin'] = pd.cut(df_snr_filtered['snr'], bins=snr_bins, include_lowest=True) |
|
|
|
|
|
snr_accuracy = df_snr_filtered.groupby('snr_bin').apply(calculate_accuracy_with_uncertainty).reset_index() |
|
|
|
|
|
snr_accuracy['snr_midpoint'] = snr_accuracy['snr_bin'].apply(lambda x: x.mid) |
|
|
|
|
|
plt.figure(figsize=(10, 6)) |
|
ax2 = plt.gca() |
|
ax2.errorbar(snr_accuracy['snr_midpoint'], snr_accuracy['accuracy'], |
|
yerr=snr_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6, |
|
capsize=5, capthick=2, elinewidth=1) |
|
ax2.set_xlabel('Signal-to-Noise Ratio (SNR)', fontsize=16) |
|
ax2.set_ylabel('Accuracy (%)', fontsize=16) |
|
ax2.set_title('Accuracy vs SNR', fontsize=18) |
|
ax2.grid(True, alpha=0.3) |
|
ax2.set_ylim(80, 100) |
|
ax2.tick_params(axis='both', which='major', labelsize=14) |
|
|
|
|
|
yticks = ax2.get_yticks() |
|
ax2.set_yticks([tick for tick in yticks if tick <= 100]) |
|
|
|
|
|
ax2.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7, |
|
label=f'Overall: {val_accuracy:.2f}%') |
|
ax2.legend(fontsize=14) |
|
|
|
|
|
ax2.text(-0.1, -0.15, '(b)', transform=ax2.transAxes, fontsize=18, fontweight='bold') |
|
|
|
plt.tight_layout() |
|
plt.savefig('models_mask/accuracy_vs_snr.pdf', dpi=300, bbox_inches='tight') |
|
plt.show() |
|
|
|
|
|
|
|
|
|
|
|
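# Repeat against boxcar width; quantile bins (pd.qcut) keep per-bin counts roughly equal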
df_boxcar_filtered = df[df['boxcar'] > 0].copy() |
|
df_boxcar_filtered['boxcar_bin'] = pd.qcut(df_boxcar_filtered['boxcar'], q=20, duplicates='drop') |
|
|
|
|
|
boxcar_accuracy = df_boxcar_filtered.groupby('boxcar_bin').apply(calculate_accuracy_with_uncertainty).reset_index() |
|
|
|
|
|
boxcar_accuracy['boxcar_midpoint'] = boxcar_accuracy['boxcar_bin'].apply(lambda x: x.mid) |
|
|
|
|
|
plt.figure(figsize=(10, 6)) |
|
ax3 = plt.gca() |
|
ax3.errorbar(boxcar_accuracy['boxcar_midpoint'], boxcar_accuracy['accuracy'], |
|
yerr=boxcar_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6, |
|
capsize=5, capthick=2, elinewidth=1) |
|
ax3.set_xscale('log')
ax3.set_xlabel('Boxcar Width (log scale) [s]', fontsize=16)
ax3.set_ylabel('Accuracy (%)', fontsize=16)
ax3.set_title('Accuracy vs Boxcar Width', fontsize=18)
ax3.grid(True, alpha=0.3)
|
ax3.set_ylim(0, 100) |
|
ax3.tick_params(axis='both', which='major', labelsize=14) |
|
|
|
|
|
yticks = ax3.get_yticks() |
|
ax3.set_yticks([tick for tick in yticks if tick <= 100]) |
|
|
|
|
|
ax3.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7, |
|
label=f'Overall: {val_accuracy:.2f}%') |
|
ax3.legend(fontsize=14) |
|
|
|
|
|
ax3.text(-0.1, -0.15, '(c)', transform=ax3.transAxes, fontsize=18, fontweight='bold') |
|
|
|
plt.tight_layout() |
|
plt.savefig('models_mask/accuracy_vs_boxcar.pdf', dpi=300, bbox_inches='tight') |
|
plt.show() |
|
|
|
|
|
print("Plots saved to models_mask/accuracy_vs_dm.pdf, models_mask/accuracy_vs_snr.pdf, and models_mask/accuracy_vs_boxcar.pdf")
|
|
|
|
|
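# Combined three-panel figure: the same DM, SNR, and boxcar panels side by side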
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6)) |
|
|
|
|
|
ax1.errorbar(dm_accuracy['dm_midpoint'], dm_accuracy['accuracy'], |
|
yerr=dm_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6, |
|
capsize=5, capthick=2, elinewidth=1) |
|
ax1.set_xlabel('Dispersion Measure (DM) [pc cm$^{-3}$]', fontsize=16) |
|
ax1.set_ylabel('Accuracy (%)', fontsize=16) |
|
|
|
ax1.grid(True, alpha=0.3) |
|
ax1.set_ylim(97, 100.5) |
|
ax1.tick_params(axis='both', which='major', labelsize=14) |
|
|
|
|
|
yticks = ax1.get_yticks() |
|
ax1.set_yticks([tick for tick in yticks if tick <= 100]) |
|
|
|
ax1.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7, |
|
label=f'Overall: {val_accuracy:.2f}%') |
|
ax1.legend(fontsize=14) |
|
|
|
|
|
ax2.errorbar(snr_accuracy['snr_midpoint'], snr_accuracy['accuracy'], |
|
yerr=snr_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6, |
|
capsize=5, capthick=2, elinewidth=1) |
|
ax2.set_xlabel('Signal-to-Noise Ratio (SNR)', fontsize=16) |
|
|
|
ax2.grid(True, alpha=0.3) |
|
ax2.set_ylim(88, 100.5) |
|
ax2.tick_params(axis='both', which='major', labelsize=14) |
|
|
|
|
|
yticks = ax2.get_yticks() |
|
ax2.set_yticks([tick for tick in yticks if tick <= 100]) |
|
|
|
ax2.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7, |
|
label=f'Overall: {val_accuracy:.2f}%') |
|
ax2.legend(fontsize=14) |
|
|
|
|
|
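# The last boxcar bin is dropped from this panel (presumably an under-populated edge bin)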
ax3.errorbar(boxcar_accuracy['boxcar_midpoint'][:-1], |
|
boxcar_accuracy['accuracy'][:-1], |
|
yerr=boxcar_accuracy['std_error'][:-1], fmt='o-', color='#b80707', linewidth=2, markersize=6, |
|
capsize=5, capthick=2, elinewidth=1) |
|
ax3.set_xscale('log') |
|
ax3.set_xlabel('Boxcar Width (log scale) [s]', fontsize=16) |
|
|
|
ax3.grid(True, alpha=0.3) |
|
ax3.set_ylim(96, 100.5) |
|
ax3.tick_params(axis='both', which='major', labelsize=14) |
|
|
|
|
|
yticks = ax3.get_yticks() |
|
ax3.set_yticks([tick for tick in yticks if tick <= 100]) |
|
|
|
ax3.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7, |
|
label=f'Overall: {val_accuracy:.2f}%') |
|
ax3.legend(fontsize=14) |
|
|
|
|
|
ax1.text(-0.1, -0.15, '(a)', transform=ax1.transAxes, fontsize=18, fontweight='bold') |
|
ax2.text(-0.1, -0.15, '(b)', transform=ax2.transAxes, fontsize=18, fontweight='bold') |
|
ax3.text(-0.1, -0.15, '(c)', transform=ax3.transAxes, fontsize=18, fontweight='bold') |
|
|
|
plt.tight_layout() |
|
plt.savefig('models_mask/accuracy_vs_all_parameters.pdf', |
|
dpi=300, bbox_inches='tight', |
|
pad_inches=0.1, format='pdf') |
|
plt.show() |
|
|
|
print("Combined plot saved to models_mask/accuracy_vs_all_parameters.pdf")
|
|