from torch.utils.data import DataLoader
from utils import TestingDataset, transform
from resnet_model_mask import ResidualBlock, ResNet
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle

torch.manual_seed(1)
# torch.manual_seed(42)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
print(num_gpus)

# Test dataset and loader
test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'
test_dataset = TestingDataset(test_data_dir, transform=transform)
num_classes = 2
testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)

# Build the ResNet classifier and wrap it for multi-GPU inference
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
model = nn.DataParallel(model)
model = model.to(device)
params = sum(p.numel() for p in model.parameters())
print("num params ", params)

# Load the trained checkpoint and switch to evaluation mode
model_1 = 'models_mask/model-43-99.235_42.pt'
# model_1 = 'models/model-47-99.125.pt'
model.load_state_dict(torch.load(model_1, weights_only=True))
model.eval()

# Evaluation loop: collect outputs, predictions, and the injection parameters
# (true label, DM, frequency, SNR, boxcar width) for every test sample
correct_valid = 0
total = 0
results = {'output': [], 'pred': [], 'true': [], 'freq': [], 'snr': [], 'dm': [], 'boxcar': []}
with torch.no_grad():
    for images, labels in tqdm(testloader):
        inputs = images.to(device)
        outputs = model(inputs, return_mask=True)
        _, predicted = torch.max(outputs, 1)
        results['output'].extend(outputs.cpu().numpy().tolist())
        results['pred'].extend(predicted.cpu().numpy().tolist())
        results['true'].extend(labels[0].cpu().numpy().tolist())
        results['freq'].extend(labels[2].cpu().numpy().tolist())
        results['dm'].extend(labels[1].cpu().numpy().tolist())
        results['snr'].extend(labels[3].cpu().numpy().tolist())
        results['boxcar'].extend(labels[4].cpu().numpy().tolist())
        total += labels[0].size(0)
        correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()

# Overall test accuracy
val_accuracy = correct_valid / total * 100.0
print("===========================")
print('accuracy: ', val_accuracy)
print("===========================")

# Pickle the results dictionary to a file
with open('models_mask/test_42.pkl', 'wb') as f:
    pickle.dump(results, f)

from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix

# Ground-truth and predicted binary labels
true = results['true']
pred = results['pred']

# Compute classification metrics
precision = precision_score(true, pred)
recall = recall_score(true, pred)
f1 = f1_score(true, pred)

# Confusion matrix: TN, FP, FN, TP
tn, fp, fn, tp = confusion_matrix(true, pred).ravel()

# False positive rate
fpr = fp / (fp + tn)
print(f"False Positive Rate: {fpr:.3f}")
print(f"Precision: {precision:.3f}")
print(f"Recall: {recall:.3f}")
print(f"F1 Score: {f1:.3f}")

# Plot accuracy as a function of DM
# Create a DataFrame for easier manipulation
df = pd.DataFrame({
    'dm': results['dm'],
    'true': results['true'],
    'pred': results['pred'],
    'snr': results['snr'],
    'freq': results['freq'],
    'boxcar': np.array(results['boxcar']) / 2
})

# Filter to only include positive-class samples (true == 1)
df = df[df['true'] == 1].copy()
print(f"Filtered to {len(df)} samples with true label = 1")

# Create DM bins for grouping (20 edges -> 19 bins)
dm_bins = np.linspace(df['dm'].min(), df['dm'].max(), 20)
df['dm_bin'] = pd.cut(df['dm'], bins=dm_bins, include_lowest=True)
print('min boxcar', df['boxcar'].min())

# Calculate accuracy and its uncertainty for each DM bin
def calculate_accuracy_with_uncertainty(group):
    correct = (group['true'] == group['pred']).sum()
    total = len(group)
    accuracy = correct / total * 100
    # Standard error of a binomial proportion, converted to a percentage
    p = correct / total
    se = np.sqrt(p * (1 - p) / total) * 100
    return pd.Series({'accuracy': accuracy, 'std_error': se, 'n_samples': total})

dm_accuracy = df.groupby('dm_bin').apply(calculate_accuracy_with_uncertainty).reset_index()
# Get the midpoint of each bin for plotting
dm_accuracy['dm_midpoint'] = dm_accuracy['dm_bin'].apply(lambda x: x.mid)

# Create the DM plot with error bars
plt.figure(figsize=(10, 6))
ax1 = plt.gca()
ax1.errorbar(dm_accuracy['dm_midpoint'], dm_accuracy['accuracy'], yerr=dm_accuracy['std_error'],
             fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax1.set_xlabel('Dispersion Measure (DM) [pc cm$^{-3}$]', fontsize=16)
ax1.set_ylabel('Accuracy (%)', fontsize=16)
ax1.set_title('Accuracy vs Dispersion Measure', fontsize=18)
ax1.grid(True, alpha=0.3)
ax1.set_ylim(97, 100)
ax1.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax1.get_yticks()
ax1.set_yticks([tick for tick in yticks if tick <= 100])
# Overall-accuracy reference line
ax1.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax1.legend(fontsize=14)
# Panel label at the bottom
ax1.text(-0.1, -0.15, '(a)', transform=ax1.transAxes, fontsize=18, fontweight='bold')
plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_dm.pdf', dpi=300, bbox_inches='tight')
plt.show()

# Plot accuracy as a function of SNR
# Filter out zero/negative SNR values (not physically meaningful)
df_snr_filtered = df[df['snr'] > 0].copy()
# Create SNR bins for grouping (20 edges -> 19 bins)
snr_bins = np.linspace(df_snr_filtered['snr'].min(), df_snr_filtered['snr'].max(), 20)
df_snr_filtered['snr_bin'] = pd.cut(df_snr_filtered['snr'], bins=snr_bins, include_lowest=True)
# Calculate accuracy and uncertainty for each SNR bin
snr_accuracy = df_snr_filtered.groupby('snr_bin').apply(calculate_accuracy_with_uncertainty).reset_index()
# Get the midpoint of each bin for plotting
snr_accuracy['snr_midpoint'] = snr_accuracy['snr_bin'].apply(lambda x: x.mid)

# Create the SNR plot with error bars
plt.figure(figsize=(10, 6))
ax2 = plt.gca()
ax2.errorbar(snr_accuracy['snr_midpoint'], snr_accuracy['accuracy'], yerr=snr_accuracy['std_error'],
             fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax2.set_xlabel('Signal-to-Noise Ratio (SNR)', fontsize=16)
ax2.set_ylabel('Accuracy (%)', fontsize=16)
ax2.set_title('Accuracy vs SNR', fontsize=18)
ax2.grid(True, alpha=0.3)
ax2.set_ylim(80, 100)
ax2.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax2.get_yticks()
ax2.set_yticks([tick for tick in yticks if tick <= 100])
# Overall-accuracy reference line
ax2.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax2.legend(fontsize=14)
# Panel label at the bottom
ax2.text(-0.1, -0.15, '(b)', transform=ax2.transAxes, fontsize=18, fontweight='bold')
plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_snr.pdf', dpi=300, bbox_inches='tight')
plt.show()

# Plot accuracy as a function of boxcar width
# Use quantile-based binning so that every bin has data
# Filter out zero/negative values first for meaningful analysis
df_boxcar_filtered = df[df['boxcar'] > 0].copy()
df_boxcar_filtered['boxcar_bin'] = pd.qcut(df_boxcar_filtered['boxcar'], q=20, duplicates='drop')
# Calculate accuracy and uncertainty for each boxcar bin
boxcar_accuracy = df_boxcar_filtered.groupby('boxcar_bin').apply(calculate_accuracy_with_uncertainty).reset_index()
# Get the midpoint of each bin for plotting
boxcar_accuracy['boxcar_midpoint'] = boxcar_accuracy['boxcar_bin'].apply(lambda x: x.mid)

# Create the boxcar plot with error bars
plt.figure(figsize=(10, 6))
ax3 = plt.gca()
ax3.errorbar(boxcar_accuracy['boxcar_midpoint'], boxcar_accuracy['accuracy'], yerr=boxcar_accuracy['std_error'],
             fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax3.set_xscale('log')
ax3.set_xlabel('Boxcar Width (log scale)', fontsize=16)
ax3.set_ylabel('Accuracy (%)', fontsize=16)
# ax3.set_title('Accuracy vs Boxcar Width (Log Scale)', fontsize=18)
ax3.grid(True, alpha=0.3)
ax3.set_ylim(0, 100)
ax3.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax3.get_yticks()
ax3.set_yticks([tick for tick in yticks if tick <= 100])
# Overall-accuracy reference line
ax3.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax3.legend(fontsize=14)
# Panel label at the bottom
ax3.text(-0.1, -0.15, '(c)', transform=ax3.transAxes, fontsize=18, fontweight='bold')
plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_boxcar.pdf', dpi=300, bbox_inches='tight')
plt.show()

print("Plots saved to models_mask/accuracy_vs_dm.pdf, models_mask/accuracy_vs_snr.pdf, "
      "and models_mask/accuracy_vs_boxcar.pdf")

# Create a combined plot with all three parameters
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

# DM panel with error bars
ax1.errorbar(dm_accuracy['dm_midpoint'], dm_accuracy['accuracy'], yerr=dm_accuracy['std_error'],
             fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax1.set_xlabel('Dispersion Measure (DM) [pc cm$^{-3}$]', fontsize=16)
ax1.set_ylabel('Accuracy (%)', fontsize=16)
# ax1.set_title('Accuracy vs DM', fontsize=18)
ax1.grid(True, alpha=0.3)
ax1.set_ylim(97, 100.5)
ax1.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax1.get_yticks()
ax1.set_yticks([tick for tick in yticks if tick <= 100])
ax1.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax1.legend(fontsize=14)

# SNR panel with error bars
ax2.errorbar(snr_accuracy['snr_midpoint'], snr_accuracy['accuracy'], yerr=snr_accuracy['std_error'],
             fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax2.set_xlabel('Signal-to-Noise Ratio (SNR)', fontsize=16)
# ax2.set_title('Accuracy vs SNR', fontsize=18)
ax2.grid(True, alpha=0.3)
ax2.set_ylim(88, 100.5)
ax2.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax2.get_yticks()
ax2.set_yticks([tick for tick in yticks if tick <= 100])
ax2.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax2.legend(fontsize=14)

# Boxcar panel (log scale) with error bars; the last bin is excluded
ax3.errorbar(boxcar_accuracy['boxcar_midpoint'][:-1], boxcar_accuracy['accuracy'][:-1],
             yerr=boxcar_accuracy['std_error'][:-1],
             fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax3.set_xscale('log')
ax3.set_xlabel('Boxcar Width (log scale) [s]', fontsize=16)
# ax3.set_title('Accuracy vs Boxcar Width', fontsize=18)
ax3.grid(True, alpha=0.3)
ax3.set_ylim(96, 100.5)
ax3.tick_params(axis='both', which='major', labelsize=14)
# Remove y-axis ticks over 100
yticks = ax3.get_yticks()
ax3.set_yticks([tick for tick in yticks if tick <= 100])
ax3.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax3.legend(fontsize=14)

# Panel labels at the bottom
ax1.text(-0.1, -0.15, '(a)', transform=ax1.transAxes, fontsize=18, fontweight='bold')
ax2.text(-0.1, -0.15, '(b)', transform=ax2.transAxes, fontsize=18, fontweight='bold')
ax3.text(-0.1, -0.15, '(c)', transform=ax3.transAxes, fontsize=18, fontweight='bold')

plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_all_parameters.pdf', dpi=300, bbox_inches='tight',
            pad_inches=0.1, format='pdf')
plt.show()

print("Combined plot saved to models_mask/accuracy_vs_all_parameters.pdf")