# BLADE_FRBNN/models/plot_params.py
from torch.utils.data import Dataset, DataLoader
from utils import CustomDataset, TestingDataset, transform, Convert_ONNX
from resnet_model_mask import ResidualBlock, ResNet
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pickle
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
torch.manual_seed(1)
# torch.manual_seed(42)
# Select the first GPU if available, otherwise fall back to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
print("num GPUs:", num_gpus)
# Load the held-out test set and build the ResNet classifier
test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'
test_dataset = TestingDataset(test_data_dir, transform=transform)
num_classes = 2
testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)

model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
model = nn.DataParallel(model)
model = model.to(device)
params = sum(p.numel() for p in model.parameters())
print("num params:", params)
# Load the trained checkpoint and switch to evaluation mode
model_1 = 'models_mask/model-43-99.235_42.pt'
# model_1 = 'models/model-47-99.125.pt'
model.load_state_dict(torch.load(model_1, weights_only=True))
model = model.eval()
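# Note (assumption): loading into the nn.DataParallel wrapper above expects the
# checkpoint keys to carry the 'module.' prefix. If a checkpoint were instead saved
# from an unwrapped model, a sketch like this hypothetical helper (not used below)
# could adapt the keys before load_state_dict:
def add_module_prefix(state_dict):
    # Prepend 'module.' to each key so an unwrapped checkpoint loads into DataParallel.
    return {k if k.startswith('module.') else f'module.{k}': v
            for k, v in state_dict.items()}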
# Evaluate on the test set, collecting outputs, predictions, and injection metadata
val_loss = 0.0
correct_valid = 0
total = 0
results = {'output': [], 'pred': [], 'true': [], 'freq': [], 'snr': [], 'dm': [], 'boxcar': []}
model.eval()
with torch.no_grad():
    for images, labels in tqdm(testloader):
        inputs = images.to(device)
        outputs = model(inputs, return_mask=True)
        _, predicted = torch.max(outputs, 1)
        # Store raw outputs, predictions, and the per-sample injection parameters
        results['output'].extend(outputs.cpu().numpy().tolist())
        results['pred'].extend(predicted.cpu().numpy().tolist())
        results['true'].extend(labels[0].cpu().numpy().tolist())
        results['freq'].extend(labels[2].cpu().numpy().tolist())
        results['dm'].extend(labels[1].cpu().numpy().tolist())
        results['snr'].extend(labels[3].cpu().numpy().tolist())
        results['boxcar'].extend(labels[4].cpu().numpy().tolist())
        total += labels[0].size(0)
        correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()
# Overall test-set accuracy
val_accuracy = correct_valid / total * 100.0
print("===========================")
print('accuracy: ', val_accuracy)
print("===========================")
# Pickle the results dictionary to a file
with open('models_mask/test_42.pkl', 'wb') as f:
    pickle.dump(results, f)
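# Convenience sketch: reloading the pickled results for later re-plotting
# (hypothetical helper, not called in this script).
def load_results(path='models_mask/test_42.pkl'):
    with open(path, 'rb') as f:
        return pickle.load(f)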
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix

# Binary classification metrics on the test set
true = results['true']  # ground-truth labels
pred = results['pred']  # predicted labels
# Compute metrics
precision = precision_score(true, pred)
recall = recall_score(true, pred)
f1 = f1_score(true, pred)
# Get confusion matrix: TN, FP, FN, TP
tn, fp, fn, tp = confusion_matrix(true, pred).ravel()
# Compute FPR
fpr = fp / (fp + tn)
print(f"False Positive Rate: {fpr:.3f}")
print(f"Precision: {precision:.3f}")
print(f"Recall: {recall:.3f}")
print(f"F1 Score: {f1:.3f}")
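# Optional sketch: scikit-learn's classification_report gives a fuller per-class
# breakdown of the same predictions (printed only; nothing is saved here).
from sklearn.metrics import classification_report
print(classification_report(true, pred, digits=3))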
# Plot accuracy as a function of DM
# Build a DataFrame for easier grouping and filtering
df = pd.DataFrame({
    'dm': results['dm'],
    'true': results['true'],
    'pred': results['pred'],
    'snr': results['snr'],
    'freq': results['freq'],
    'boxcar': np.array(results['boxcar']) / 2
})

# Keep only positive-class samples (true == 1)
df = df[df['true'] == 1].copy()
print(f"Filtered to {len(df)} samples with true label = 1")

# Create DM bins for grouping
dm_bins = np.linspace(df['dm'].min(), df['dm'].max(), 20)  # 20 bin edges -> 19 bins
df['dm_bin'] = pd.cut(df['dm'], bins=dm_bins, include_lowest=True)
print('min boxcar:', df['boxcar'].min())
# Accuracy and a binomial standard error for one bin of samples
def calculate_accuracy_with_uncertainty(group):
    correct = (group['true'] == group['pred']).sum()
    total = len(group)
    accuracy = correct / total * 100
    # Standard error of a binomial proportion (normal approximation)
    p = correct / total
    se = np.sqrt(p * (1 - p) / total) * 100  # convert to percentage
    return pd.Series({'accuracy': accuracy, 'std_error': se, 'n_samples': total})
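# Hedged alternative: the normal-approximation error above collapses to zero whenever a
# bin is classified 100% correctly. A Wilson score half-width (sketch below, not used in
# the plots) stays non-degenerate in that case.
def wilson_halfwidth(correct, total, z=1.96):
    # Half-width (in %) of the ~95% Wilson score interval for a binomial proportion.
    p = correct / total
    denom = 1 + z**2 / total
    return (z / denom) * np.sqrt(p * (1 - p) / total + z**2 / (4 * total**2)) * 100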
dm_accuracy = df.groupby('dm_bin').apply(calculate_accuracy_with_uncertainty).reset_index()
# Get the midpoint of each bin for plotting
dm_accuracy['dm_midpoint'] = dm_accuracy['dm_bin'].apply(lambda x: x.mid)
# Create the plot with error bars
plt.figure(figsize=(10, 6))
ax1 = plt.gca()
ax1.errorbar(dm_accuracy['dm_midpoint'], dm_accuracy['accuracy'],
             yerr=dm_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax1.set_xlabel('Dispersion Measure (DM) [pc cm$^{-3}$]', fontsize=16)
ax1.set_ylabel('Accuracy (%)', fontsize=16)
ax1.set_title('Accuracy vs Dispersion Measure', fontsize=18)
ax1.grid(True, alpha=0.3)
ax1.set_ylim(97, 100)
ax1.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax1.get_yticks()
ax1.set_yticks([tick for tick in yticks if tick <= 100])
# Add some statistics to the plot
ax1.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax1.legend(fontsize=14)
# Add subplot labels at the bottom
ax1.text(-0.1, -0.15, '(a)', transform=ax1.transAxes, fontsize=18, fontweight='bold')
plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_dm.pdf', dpi=300, bbox_inches='tight')
plt.show()
# Plot accuracy as a function of SNR
# Filter out zero/negative SNR values (not physically meaningful)
df_snr_filtered = df[df['snr'] > 0].copy()
# Create SNR bins for grouping
snr_bins = np.linspace(df_snr_filtered['snr'].min(), df_snr_filtered['snr'].max(), 20)  # 20 bin edges -> 19 bins
df_snr_filtered['snr_bin'] = pd.cut(df_snr_filtered['snr'], bins=snr_bins, include_lowest=True)
# Calculate accuracy and uncertainty for each SNR bin
snr_accuracy = df_snr_filtered.groupby('snr_bin').apply(calculate_accuracy_with_uncertainty).reset_index()
# Get the midpoint of each bin for plotting
snr_accuracy['snr_midpoint'] = snr_accuracy['snr_bin'].apply(lambda x: x.mid)
# Create the SNR plot with error bars
plt.figure(figsize=(10, 6))
ax2 = plt.gca()
ax2.errorbar(snr_accuracy['snr_midpoint'], snr_accuracy['accuracy'],
             yerr=snr_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax2.set_xlabel('Signal-to-Noise Ratio (SNR)', fontsize=16)
ax2.set_ylabel('Accuracy (%)', fontsize=16)
ax2.set_title('Accuracy vs SNR', fontsize=18)
ax2.grid(True, alpha=0.3)
ax2.set_ylim(80, 100)
ax2.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax2.get_yticks()
ax2.set_yticks([tick for tick in yticks if tick <= 100])
# Add overall accuracy reference line
ax2.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax2.legend(fontsize=14)
# Add subplot labels at the bottom
ax2.text(-0.1, -0.15, '(b)', transform=ax2.transAxes, fontsize=18, fontweight='bold')
plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_snr.pdf', dpi=300, bbox_inches='tight')
plt.show()
# Plot accuracy as a function of boxcar
# Create boxcar bins for grouping
# Use quantile-based binning to ensure all bins have data
# Filter out zero/negative values first for meaningful analysis
df_boxcar_filtered = df[df['boxcar'] > 0].copy()
df_boxcar_filtered['boxcar_bin'] = pd.qcut(df_boxcar_filtered['boxcar'], q=20, duplicates='drop')
# Calculate accuracy and uncertainty for each boxcar bin
boxcar_accuracy = df_boxcar_filtered.groupby('boxcar_bin').apply(calculate_accuracy_with_uncertainty).reset_index()
# Get the midpoint of each bin for plotting
boxcar_accuracy['boxcar_midpoint'] = boxcar_accuracy['boxcar_bin'].apply(lambda x: x.mid)
# Create the boxcar plot with error bars
plt.figure(figsize=(10, 6))
ax3 = plt.gca()
ax3.errorbar(boxcar_accuracy['boxcar_midpoint'], boxcar_accuracy['accuracy'],
             yerr=boxcar_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax3.set_xscale('log')
ax3.set_xlabel('Boxcar Width (log scale)', fontsize=16)
ax3.set_ylabel('Accuracy (%)', fontsize=16)
# ax3.set_title('Accuracy vs Boxcar Width (Log Scale)', fontsize=18)
ax3.grid(True, alpha=0.3)
ax3.set_ylim(0, 100)
ax3.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax3.get_yticks()
ax3.set_yticks([tick for tick in yticks if tick <= 100])
# Add overall accuracy reference line
ax3.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax3.legend(fontsize=14)
# Add subplot labels at the bottom
ax3.text(-0.1, -0.15, '(c)', transform=ax3.transAxes, fontsize=18, fontweight='bold')
plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_boxcar.pdf', dpi=300, bbox_inches='tight')
plt.show()
print("Plots saved to models_mask/accuracy_vs_dm.pdf, models_mask/accuracy_vs_snr.pdf, and models_mask/accuracy_vs_boxcar.pdf")
# Create combined plot with all three parameters
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
# DM plot with error bars
ax1.errorbar(dm_accuracy['dm_midpoint'], dm_accuracy['accuracy'],
             yerr=dm_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax1.set_xlabel('Dispersion Measure (DM) [pc cm$^{-3}$]', fontsize=16)
ax1.set_ylabel('Accuracy (%)', fontsize=16)
# ax1.set_title('Accuracy vs DM', fontsize=18)
ax1.grid(True, alpha=0.3)
ax1.set_ylim(97, 100.5)
ax1.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax1.get_yticks()
ax1.set_yticks([tick for tick in yticks if tick <= 100])
ax1.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax1.legend(fontsize=14)
# SNR plot with error bars
ax2.errorbar(snr_accuracy['snr_midpoint'], snr_accuracy['accuracy'],
             yerr=snr_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax2.set_xlabel('Signal-to-Noise Ratio (SNR)', fontsize=16)
# ax2.set_title('Accuracy vs SNR', fontsize=18)
ax2.grid(True, alpha=0.3)
ax2.set_ylim(88, 100.5)
ax2.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax2.get_yticks()
ax2.set_yticks([tick for tick in yticks if tick <= 100])
ax2.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax2.legend(fontsize=14)
# Boxcar plot (log scale) with error bars (last bin excluded)
ax3.errorbar(boxcar_accuracy['boxcar_midpoint'][:-1],
             boxcar_accuracy['accuracy'][:-1],
             yerr=boxcar_accuracy['std_error'][:-1], fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax3.set_xscale('log')
ax3.set_xlabel('Boxcar Width (log scale) [s]', fontsize=16)
# ax3.set_title('Accuracy vs Boxcar Width', fontsize=18)
ax3.grid(True, alpha=0.3)
ax3.set_ylim(96, 100.5)
ax3.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax3.get_yticks()
ax3.set_yticks([tick for tick in yticks if tick <= 100])
ax3.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax3.legend(fontsize=14)
# Add subplot labels at the bottom
ax1.text(-0.1, -0.15, '(a)', transform=ax1.transAxes, fontsize=18, fontweight='bold')
ax2.text(-0.1, -0.15, '(b)', transform=ax2.transAxes, fontsize=18, fontweight='bold')
ax3.text(-0.1, -0.15, '(c)', transform=ax3.transAxes, fontsize=18, fontweight='bold')
plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_all_parameters.pdf',
            dpi=300, bbox_inches='tight',
            pad_inches=0.1, format='pdf')
plt.show()
print("Combined plot saved to models_mask/accuracy_vs_all_parameters.pdf")
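# Additional sketch: the 'freq' column is collected above but never plotted. Using the
# same binning approach, an accuracy-vs-frequency panel could be produced like this
# (the output filename is hypothetical; frequency units depend on the dataset).
freq_bins = np.linspace(df['freq'].min(), df['freq'].max(), 20)
df['freq_bin'] = pd.cut(df['freq'], bins=freq_bins, include_lowest=True)
freq_accuracy = df.groupby('freq_bin').apply(calculate_accuracy_with_uncertainty).reset_index()
freq_accuracy['freq_midpoint'] = freq_accuracy['freq_bin'].apply(lambda x: x.mid)

plt.figure(figsize=(10, 6))
plt.errorbar(freq_accuracy['freq_midpoint'], freq_accuracy['accuracy'],
             yerr=freq_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
plt.xlabel('Frequency', fontsize=16)
plt.ylabel('Accuracy (%)', fontsize=16)
plt.grid(True, alpha=0.3)
plt.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
plt.legend(fontsize=14)
plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_freq.pdf', dpi=300, bbox_inches='tight')
plt.show()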