|
"""Run a trained FRBNN ResNet classifier over six filterbank observations of
FRB 20240114A, collect high-confidence triggers, and plot the detections."""

from utils import CustomDataset, transform
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import numpy as np
import pickle

from tqdm import tqdm

from resnet_model_mask import ResidualBlock, ResNet
|
|
|
torch.manual_seed(1) |
|
|
|
|
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
print("num GPUs:", num_gpus)
|
|
|
|
|
# Datasets and loader as set up for training; only `transform` is actually
# reused by the inference loops below.
data_dir = '/mnt/buf0/pma/frbnn/train_ready'
dataset = CustomDataset(data_dir, transform=transform)
valid_data_dir = '/mnt/buf0/pma/frbnn/valid_ready'
valid_dataset = CustomDataset(valid_data_dir, transform=transform)

num_classes = 2
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
|
|
|
# ResNet with 24 input channels and [3, 4, 6, 3] residual blocks per stage.
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
model = nn.DataParallel(model)
model = model.to(device)
params = sum(p.numel() for p in model.parameters())
print("num params:", params)

# Load the trained checkpoint and switch to evaluation mode.
model_path = 'models/model-47-99.125.pt'
model.load_state_dict(torch.load(model_path, weights_only=True))
model = model.eval()
|
|
|
|
|
import sigpyproc.readers as r
import matplotlib.pyplot as plt
from scipy.special import softmax
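
# The six blocks below follow the same recipe: slide a 2048-sample window
# across a candidate region, classify each window with the ResNet, and record
# every window whose burst-class probability clears the trigger threshold.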
|
|
|
all_detections = [] |
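# Each entry stores the dynamic spectrum, classifier confidence, window index,
# source file name, normalisation mode, and filterbank header.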
|
|
|
|
|
print("Processing first file (SNR 180)...") |
|
fil = r.FilReader('/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60398_67123_110077819_frb20240114a_0001/LoC.C0736/decimated.fil') |
|
header = fil.header |
|
print(header) |
|
|
|
triggers = [] |
|
counter = 0 |
|
for i in tqdm(range(27085468,27397968, 2048)): |
|
data = torch.tensor(fil.read_block(i-1024, 2048)).cuda() |
|
out = model(transform(torch.tensor(data).cuda())[None]) |
|
out = softmax(out.detach().cpu().numpy(), axis=1) |
|
triggers.append(out) |
|
counter += 1 |
|
if out[0, 1]>0.9982: |
|
key = data.cpu().numpy() |
|
all_detections.append({ |
|
'data': key, |
|
'confidence': out[0, 1], |
|
'file_index': i, |
|
'file_name': 'fil_60398_67123_110077819_frb20240114a_0001 (SNR 180)', |
|
'normalization': 'raw', |
|
'header': header |
|
}) |
|
stack = np.stack(triggers) |
|
positives = stack[:,0,1] |
|
num_pos = np.where(positives > 0.9988)[0].shape[0] |
|
print(f"File 1 detections: {num_pos}") |
|
|
|
|
|
print("Processing second file (SNR 60)...") |
|
fil = r.FilReader('/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60428_58167_24730285_frb20240114a_0001/LoC.C1504/decimated.fil') |
|
header = fil.header |
|
print(header) |
|
|
|
triggers = [] |
|
counter = 0 |
|
for i in tqdm(range(8148984,8461484, 2048)): |
|
data = torch.tensor(fil.read_block(i-1024, 2048)).cuda() |
|
out = model(transform(torch.tensor(data).cuda())[None]) |
|
out = softmax(out.detach().cpu().numpy(), axis=1) |
|
triggers.append(out) |
|
counter += 1 |
|
if out[0, 1]>0.9988: |
|
key = data.cpu().numpy() |
|
result = np.repeat(np.mean(data.cpu().numpy(), axis = 1)[:,None], 2048, axis=1) |
|
all_detections.append({ |
|
'data': key/result, |
|
'confidence': out[0, 1], |
|
'file_index': i, |
|
'file_name': 'fil_60428_58167_24730285_frb20240114a_0001 (SNR 60)', |
|
'normalization': 'normalized', |
|
'header': header |
|
}) |
|
stack = np.stack(triggers) |
|
positives = stack[:,0,1] |
|
num_pos = np.where(positives > 0.9988)[0].shape[0] |
|
print(f"File 2 detections: {num_pos}") |
|
|
|
|
|
print("Processing third file...") |
|
fil = r.FilReader('/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60427_42703_18513000_frb20240114a_0001/LoC.C1504/decimated.fil') |
|
header = fil.header |
|
print(header) |
|
|
|
triggers = [] |
|
counter = 0 |
|
for i in tqdm(range(20343125,20655625, 2048)): |
|
data = torch.tensor(fil.read_block(i-1024, 2048)).cuda() |
|
out = model(transform(torch.tensor(data).cuda())[None]) |
|
out = softmax(out.detach().cpu().numpy(), axis=1) |
|
triggers.append(out) |
|
counter += 1 |
|
if out[0, 1]>0.9988: |
|
key = data.cpu().numpy() |
|
result = np.repeat(np.mean(data.cpu().numpy(), axis = 1)[:,None], 2048, axis=1) |
|
all_detections.append({ |
|
'data': key/result, |
|
'confidence': out[0, 1], |
|
'file_index': i, |
|
'file_name': 'fil_60427_42703_18513000_frb20240114a_0001', |
|
'normalization': 'normalized', |
|
'header': header |
|
}) |
|
stack = np.stack(triggers) |
|
positives = stack[:,0,1] |
|
num_pos = np.where(positives > 0.9988)[0].shape[0] |
|
print(f"File 3 detections: {num_pos}") |
|
|
|
|
|
print("Processing fourth file...") |
|
fil = r.FilReader('/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60395_72956_94613525_frb20240114a_0001/LoB.C1312/decimated.fil') |
|
header = fil.header |
|
print(header) |
|
|
|
triggers = [] |
|
counter = 0 |
|
for i in tqdm(range(8708515,9021015, 2048)): |
|
data = torch.tensor(fil.read_block(i-1024, 2048)).cuda() |
|
out = model(transform(torch.tensor(data).cuda())[None]) |
|
out = softmax(out.detach().cpu().numpy(), axis=1) |
|
triggers.append(out) |
|
counter += 1 |
|
if out[0, 1]>0.9988: |
|
key = data.cpu().numpy() |
|
result = np.repeat(np.mean(data.cpu().numpy(), axis = 1)[:,None], 2048, axis=1) |
|
all_detections.append({ |
|
'data': key/result, |
|
'confidence': out[0, 1], |
|
'file_index': i, |
|
'file_name': 'fil_60395_72956_94613525_frb20240114a_0001', |
|
'normalization': 'normalized', |
|
'header': header |
|
}) |
|
stack = np.stack(triggers) |
|
positives = stack[:,0,1] |
|
num_pos = np.where(positives > 0.9988)[0].shape[0] |
|
print(f"File 4 detections: {num_pos}") |
|
|
|
|
|
print("Processing fifth file...") |
|
fil = r.FilReader('/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60429_47342_29343017_frb20240114a_0001/LoB.C1120/decimated.fil') |
|
header = fil.header |
|
print(header) |
|
|
|
triggers = [] |
|
counter = 0 |
|
for i in tqdm(range(10399062,10711562, 2048)): |
|
data = torch.tensor(fil.read_block(i-1024, 2048)).cuda() |
|
out = model(transform(torch.tensor(data).cuda())[None]) |
|
out = softmax(out.detach().cpu().numpy(), axis=1) |
|
triggers.append(out) |
|
counter += 1 |
|
if out[0, 1]>0.9988: |
|
key = data.cpu().numpy() |
|
result = np.repeat(np.mean(data.cpu().numpy(), axis = 1)[:,None], 2048, axis=1) |
|
all_detections.append({ |
|
'data': key/result, |
|
'confidence': out[0, 1], |
|
'file_index': i, |
|
'file_name': 'fil_60429_47342_29343017_frb20240114a_0001', |
|
'normalization': 'normalized', |
|
'header': header |
|
}) |
|
stack = np.stack(triggers) |
|
positives = stack[:,0,1] |
|
num_pos = np.where(positives > 0.9988)[0].shape[0] |
|
print(f"File 5 detections: {num_pos}") |
|
|
|
|
|
print("Processing sixth file...") |
|
fil = r.FilReader('/mnt/primary/ata/results/p031/FRB20240114a_spliced/fil_60456_42557_118616821_frb20240114a_0001/LoC.C1312/decimated.fil') |
|
header = fil.header |
|
print(header) |
|
|
|
triggers = [] |
|
counter = 0 |
|
for i in tqdm(range(1250000,1562500, 2048)): |
|
data = torch.tensor(fil.read_block(i-1024, 2048)).cuda() |
|
out = model(transform(torch.tensor(data).cuda())[None]) |
|
out = softmax(out.detach().cpu().numpy(), axis=1) |
|
triggers.append(out) |
|
counter += 1 |
|
if out[0, 1]>0.9988: |
|
key = data.cpu().numpy() |
|
result = np.repeat(np.mean(data.cpu().numpy(), axis = 1)[:,None], 2048, axis=1) |
|
all_detections.append({ |
|
'data': key/result, |
|
'confidence': out[0, 1], |
|
'file_index': i, |
|
'file_name': 'fil_60456_42557_118616821_frb20240114a_0001', |
|
'normalization': 'normalized', |
|
'header': header |
|
}) |
|
stack = np.stack(triggers) |
|
positives = stack[:,0,1] |
|
num_pos = np.where(positives > 0.9988)[0].shape[0] |
|
print(f"File 6 detections: {num_pos}") |
|
|
|
|
|
print(f"\nTotal detections found: {len(all_detections)}") |
|
|
|
if len(all_detections) > 0: |
|
|
|
all_detections.sort(key=lambda x: x['confidence'], reverse=True) |
|
|
|
|
|
n_detections = len(all_detections) |
|
cols = 2 |
|
rows = 5 |
|
|
|
fig, axes = plt.subplots(rows, cols, figsize=(10, 12)) |
|
|
|
|
|
axes_flat = axes.flatten() |
|
|
|
for idx, detection in enumerate(all_detections): |
|
ax = axes_flat[idx] |
|
|
|
|
|
data_median = np.median(detection['data']) |
|
im = ax.imshow(detection['data'], aspect=6, cmap='hot', vmin=data_median) |
|
|
|
|
|
|
|
time_increment = 6.5e-5 |
|
n_samples = detection['data'].shape[1] |
|
total_time = n_samples * time_increment |
|
|
|
|
|
n_ticks = 5 |
|
tick_positions = np.linspace(0, n_samples-1, n_ticks) |
|
tick_labels = [f"{i*time_increment:.2f}" for i in tick_positions] |
|
|
|
ax.set_xticks(tick_positions) |
|
ax.set_xticklabels(tick_labels, fontsize=12) |
|
|
|
|
|
if idx >= 8: |
|
ax.set_xlabel('Time (seconds)', fontsize=14) |
|
|
|
|
|
header = detection['header'] |
|
fch1 = header.fch1 |
|
foff = header.foff |
|
nchans = header.nchans |
|
|
|
|
|
freq_start = fch1 |
|
freq_end = fch1 + (nchans - 1) * foff |
|
|
|
|
|
n_freq_ticks = 5 |
|
freq_tick_positions = np.linspace(0, nchans-1, n_freq_ticks) |
|
freq_values = [fch1 + i * foff for i in freq_tick_positions] |
|
freq_labels = [f"{freq:.1f}" for freq in freq_values] |
|
|
|
ax.set_yticks(freq_tick_positions) |
|
ax.set_yticklabels(freq_labels, fontsize=12) |
|
|
|
|
|
if idx % 2 == 0: |
|
ax.set_ylabel('Freq. (MHz)', fontsize=14) |
|
|
|
|
|
ax.tick_params(axis='both', which='major', size=3) |
|
|
|
|
|
for idx in range(n_detections, len(axes_flat)): |
|
axes_flat[idx].set_visible(False) |
|
|
|
|
|
plt.subplots_adjust(hspace=0.3, wspace=0.2) |
|
plt.savefig('combined_frb_detections.pdf', dpi=150, bbox_inches='tight', format='pdf') |
|
plt.show() |
|
|
|
print(f"Combined plot saved as 'combined_frb_detections.png'") |
|
|
|
|
|
print("\nDetection Summary:") |
|
for i, detection in enumerate(all_detections): |
|
print(f"{i+1}. {detection['file_name'][:50]}... - Confidence: {detection['confidence']:.4f}") |
|
else: |
|
print("No detections found across all files.") |