"""Scan six FRB20240114a filterbank files with a trained ResNet and plot the hits.

For every file a fixed sample range is walked in 2048-sample windows. Each
window is passed through ``transform`` and the model, the logits are turned
into class probabilities with softmax, and windows whose class-1 probability
exceeds the per-file threshold are collected. All collected windows are then
drawn, sorted by confidence, into one multi-panel PDF figure.
"""
# Import order follows the original script (only duplicates were removed).
import math

from utils import CustomDataset, transform, Convert_ONNX
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
from resnet_model_mask import ResidualBlock, ResNet
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pickle
import sigpyproc.readers as r
import cv2
import matplotlib.pyplot as plt
from scipy.special import softmax

torch.manual_seed(1)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
print(num_gpus)

# NOTE(review): the datasets and the loader below are never used by this
# inference script; they look like leftovers from the training script. They
# are kept so that any side effects of constructing them are unchanged.
data_dir = '/mnt/buf0/pma/frbnn/train_ready'
dataset = CustomDataset(data_dir, transform=transform)
valid_data_dir = '/mnt/buf0/pma/frbnn/valid_ready'
valid_dataset = CustomDataset(valid_data_dir, transform=transform)
num_classes = 2
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)

model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
model = nn.DataParallel(model)
model = model.to(device)
params = sum(p.numel() for p in model.parameters())
print("num params ",params)

model_path = 'models/model-47-99.125.pt'
# map_location lets the checkpoint load on hosts without the GPU it was saved from.
model.load_state_dict(torch.load(model_path, weights_only=True, map_location=device))
model = model.eval()

# Number of time samples per classified window; also the stride of the scan.
BLOCK_SAMPLES = 2048
# Threshold used for the per-file "detections" count that is printed.
COUNT_THRESHOLD = 0.9988
# Seconds per time sample, used only for labelling the time axis.
TIME_PER_SAMPLE = 6.5e-5
# Grid layout of the combined figure (grows in rows if there are more hits).
PLOT_COLS = 2
PLOT_MIN_ROWS = 5
PLOT_PATH = 'combined_frb_detections.pdf'

# One entry per file: what to print, where to read, which sample range to
# scan, and how detections are selected and stored.
# NOTE(review): the first file collects detections at 0.9982 but its printed
# count still uses COUNT_THRESHOLD (0.9988), exactly as in the original
# script -- confirm that this difference is intended.
FILES = [
    {
        'banner': "Processing first file (SNR 180)...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60398_67123_110077819_frb20240114a_0001/LoC.C0736/decimated.fil',
        'start': 27085468,
        'stop': 27397968,
        'label': 'fil_60398_67123_110077819_frb20240114a_0001 (SNR 180)',
        'detect_threshold': 0.9982,
        'normalize': False,
    },
    {
        'banner': "Processing second file (SNR 60)...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60428_58167_24730285_frb20240114a_0001/LoC.C1504/decimated.fil',
        'start': 8148984,
        'stop': 8461484,
        'label': 'fil_60428_58167_24730285_frb20240114a_0001 (SNR 60)',
        'detect_threshold': 0.9988,
        'normalize': True,
    },
    {
        'banner': "Processing third file...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60427_42703_18513000_frb20240114a_0001/LoC.C1504/decimated.fil',
        'start': 20343125,
        'stop': 20655625,
        'label': 'fil_60427_42703_18513000_frb20240114a_0001',
        'detect_threshold': 0.9988,
        'normalize': True,
    },
    {
        'banner': "Processing fourth file...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60395_72956_94613525_frb20240114a_0001/LoB.C1312/decimated.fil',
        'start': 8708515,
        'stop': 9021015,
        'label': 'fil_60395_72956_94613525_frb20240114a_0001',
        'detect_threshold': 0.9988,
        'normalize': True,
    },
    {
        'banner': "Processing fifth file...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60429_47342_29343017_frb20240114a_0001/LoB.C1120/decimated.fil',
        'start': 10399062,
        'stop': 10711562,
        'label': 'fil_60429_47342_29343017_frb20240114a_0001',
        'detect_threshold': 0.9988,
        'normalize': True,
    },
    {
        'banner': "Processing sixth file...",
        'path': '/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60456_42557_118616821_frb20240114a_0001/LoC.C1312/decimated.fil',
        'start': 1250000,
        'stop': 1562500,
        'label': 'fil_60456_42557_118616821_frb20240114a_0001',
        'detect_threshold': 0.9988,
        'normalize': True,
    },
]


def scan_file(path, start, stop, label, detect_threshold, normalize):
    """Classify consecutive windows of one file and collect the detections.

    Args:
        path: Path handed to ``sigpyproc.readers.FilReader``.
        start: First window centre (sample index) of the scan.
        stop: Exclusive upper bound for the window centres.
        label: Text stored as ``'file_name'`` in every detection record.
        detect_threshold: A window is kept when its class-1 probability is
            strictly greater than this value.
        normalize: If true, each row of the stored window is divided by its
            own mean over axis 1 and the record is tagged ``'normalized'``;
            otherwise the window is stored unchanged and tagged ``'raw'``.

    Returns:
        A ``(detections, num_pos)`` tuple. ``detections`` is a list of dicts
        with the keys ``data``, ``confidence``, ``file_index``, ``file_name``,
        ``normalization`` and ``header``. ``num_pos`` is the number of windows
        whose class-1 probability exceeds ``COUNT_THRESHOLD``.
    """
    fil = r.FilReader(path)
    header = fil.header
    print(header)

    scores = []
    detections = []
    for i in tqdm(range(start, stop, BLOCK_SAMPLES)):
        # The window is centred on i: it starts half a block earlier.
        block = fil.read_block(i - BLOCK_SAMPLES // 2, BLOCK_SAMPLES)
        data = torch.tensor(block).to(device)

        # Inference only: no autograd graph is needed. The model input is a
        # copy, so `data` stays untouched for the detection record below.
        with torch.no_grad():
            logits = model(transform(data.clone().detach())[None])
        out = softmax(logits.detach().cpu().numpy(), axis=1)
        scores.append(out)

        if out[0, 1] > detect_threshold:
            key = data.cpu().numpy()
            if normalize:
                # Broadcasting the per-row mean over axis 1 is equivalent to
                # repeating it across the window width.
                key = key / np.mean(key, axis=1)[:, None]
            detections.append({
                'data': key,
                'confidence': out[0, 1],
                'file_index': i,
                'file_name': label,
                'normalization': 'normalized' if normalize else 'raw',
                'header': header,
            })

    num_pos = sum(1 for score in scores if score[0, 1] > COUNT_THRESHOLD)
    return detections, num_pos


def plot_detections(detections, output_path=PLOT_PATH):
    """Draw every detection into one grid figure and save it as a PDF.

    Args:
        detections: Non-empty list of detection records as produced by
            ``scan_file``; panels are drawn in list order.
        output_path: File the figure is written to.

    The grid has ``PLOT_COLS`` columns and at least ``PLOT_MIN_ROWS`` rows;
    rows are added when there are more detections than panels, so indexing
    the axes can no longer run past the end of the grid.
    """
    n_detections = len(detections)
    cols = PLOT_COLS
    rows = max(PLOT_MIN_ROWS, math.ceil(n_detections / cols))

    # 10x12 inches for the default 5 rows; taller in proportion beyond that.
    fig, axes = plt.subplots(rows, cols, figsize=(10, 12 * rows / PLOT_MIN_ROWS))
    axes_flat = axes.flatten()

    for idx, detection in enumerate(detections):
        ax = axes_flat[idx]
        window = detection['data']

        # Clipping the colour scale at the median gives better contrast.
        ax.imshow(window, aspect=6, cmap='hot', vmin=np.median(window))

        # Time axis: five evenly spaced ticks labelled in seconds.
        n_samples = window.shape[1]
        tick_positions = np.linspace(0, n_samples - 1, 5)
        tick_labels = [f"{pos*TIME_PER_SAMPLE:.2f}" for pos in tick_positions]
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, fontsize=12)

        # Only panels in the last grid row get an x-axis label.
        if idx >= (rows - 1) * cols:
            ax.set_xlabel('Time (seconds)', fontsize=14)

        # Frequency axis: five evenly spaced ticks, labelled from the header's
        # first-channel frequency (fch1) and channel offset (foff).
        header = detection['header']
        fch1 = header.fch1
        foff = header.foff
        nchans = header.nchans
        freq_tick_positions = np.linspace(0, nchans - 1, 5)
        freq_labels = [f"{fch1 + pos * foff:.1f}" for pos in freq_tick_positions]
        ax.set_yticks(freq_tick_positions)
        ax.set_yticklabels(freq_labels, fontsize=12)

        # Only panels in the first column get a y-axis label.
        if idx % cols == 0:
            ax.set_ylabel('Freq. (MHz)', fontsize=14)

        # Make tick markers smaller.
        ax.tick_params(axis='both', which='major', size=3)

    # Hide the panels that have no detection to show.
    for unused_ax in axes_flat[n_detections:]:
        unused_ax.set_visible(False)

    # Reduce whitespace between plots.
    plt.subplots_adjust(hspace=0.3, wspace=0.2)
    plt.savefig(output_path, dpi=150, bbox_inches='tight', format='pdf')
    plt.show()
    print(f"Combined plot saved as '{output_path}'")


# Collect all plotting data
all_detections = []
for file_number, spec in enumerate(FILES, start=1):
    print(spec['banner'])
    found, num_pos = scan_file(
        spec['path'],
        spec['start'],
        spec['stop'],
        spec['label'],
        spec['detect_threshold'],
        spec['normalize'],
    )
    all_detections.extend(found)
    print(f"File {file_number} detections: {num_pos}")

# Create combined plot
print(f"\nTotal detections found: {len(all_detections)}")

if all_detections:
    # Sort detections by confidence (highest first)
    all_detections.sort(key=lambda x: x['confidence'], reverse=True)

    plot_detections(all_detections)

    # Print summary
    print("\nDetection Summary:")
    for i, detection in enumerate(all_detections):
        print(f"{i+1}. {detection['file_name'][:50]}... - Confidence: {detection['confidence']:.4f}")
else:
    print("No detections found across all files.")