# BLADE_FRBNN / models/recover_real_frb.py
# Uploaded by peterma02 using huggingface_hub (commit f3972ea, verified).
from utils import CustomDataset, transform, Convert_ONNX
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
from resnet_model_mask import ResidualBlock, ResNet
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pickle
# --- Reproducibility, device selection and model loading ---------------------
torch.manual_seed(1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
print(num_gpus)

# NOTE: the training/validation CustomDataset and DataLoader previously built
# here were never used by this inference-only script; building them only made
# the script require the training-data directories to exist.
num_classes = 2
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes)
# Wrapped in DataParallel before loading, presumably because the checkpoint was
# saved from a DataParallel model ("module."-prefixed keys) -- TODO confirm.
model = nn.DataParallel(model).to(device)
params = sum(p.numel() for p in model.parameters())
print("num params ", params)
model_path = 'models/model-47-99.125.pt'
# map_location lets a checkpoint saved on GPU load on the CPU fallback device.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model = model.eval()
# Collect all plotting data
import sigpyproc.readers as r
import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import softmax
from tqdm import tqdm
# --- Scan each filterbank file and collect high-confidence detections --------
_FRB_ROOT = '/mnt/primary/ata/results/p031/FRB20240114a_spliced'
_BLOCK_LEN = 2048          # samples per model input window (also the step size)
_COUNT_THRESHOLD = 0.9988  # class-1 probability threshold for the per-file count

# One entry per file, scanned in this order. 'obs'/'beam' build the path under
# _FRB_ROOT; 'label' is stored as the detection's 'file_name'.
# File 1 keeps its own looser detection threshold (0.9982) and stores raw
# data; its printed count still uses _COUNT_THRESHOLD, like every other file,
# so the count and the number of stored detections can differ for it.
_FILES = [
    dict(message="Processing first file (SNR 180)...",
         obs='fil_60398_67123_110077819_frb20240114a_0001', beam='LoC.C0736',
         start=27085468, stop=27397968,
         label='fil_60398_67123_110077819_frb20240114a_0001 (SNR 180)',
         detect_threshold=0.9982, normalize=False),
    dict(message="Processing second file (SNR 60)...",
         obs='fil_60428_58167_24730285_frb20240114a_0001', beam='LoC.C1504',
         start=8148984, stop=8461484,
         label='fil_60428_58167_24730285_frb20240114a_0001 (SNR 60)',
         detect_threshold=0.9988, normalize=True),
    dict(message="Processing third file...",
         obs='fil_60427_42703_18513000_frb20240114a_0001', beam='LoC.C1504',
         start=20343125, stop=20655625,
         label='fil_60427_42703_18513000_frb20240114a_0001',
         detect_threshold=0.9988, normalize=True),
    dict(message="Processing fourth file...",
         obs='fil_60395_72956_94613525_frb20240114a_0001', beam='LoB.C1312',
         start=8708515, stop=9021015,
         label='fil_60395_72956_94613525_frb20240114a_0001',
         detect_threshold=0.9988, normalize=True),
    dict(message="Processing fifth file...",
         obs='fil_60429_47342_29343017_frb20240114a_0001', beam='LoB.C1120',
         start=10399062, stop=10711562,
         label='fil_60429_47342_29343017_frb20240114a_0001',
         detect_threshold=0.9988, normalize=True),
    dict(message="Processing sixth file...",
         obs='fil_60456_42557_118616821_frb20240114a_0001', beam='LoC.C1312',
         start=1250000, stop=1562500,
         label='fil_60456_42557_118616821_frb20240114a_0001',
         detect_threshold=0.9988, normalize=True),
]


def _scan_file(path, start, stop, file_name, *, detect_threshold, normalize):
    """Run the classifier over sample indices ``range(start, stop, _BLOCK_LEN)``.

    For each step index ``i``, reads ``_BLOCK_LEN`` samples beginning at
    ``i - _BLOCK_LEN // 2``, classifies them, and records the softmax
    probability of class 1. Windows whose probability exceeds
    ``detect_threshold`` are returned as detection dicts.

    Returns:
        (detections, num_pos): the detection dicts in scan order, and the
        number of windows whose class-1 probability exceeds _COUNT_THRESHOLD.
    """
    fil = r.FilReader(path)
    header = fil.header
    print(header)
    detections = []
    confidences = []
    with torch.no_grad():  # inference only: don't build autograd graphs
        for i in tqdm(range(start, stop, _BLOCK_LEN)):
            data = torch.tensor(fil.read_block(i - _BLOCK_LEN // 2, _BLOCK_LEN)).to(device)
            # transform gets its own copy (as the original torch.tensor(data)
            # re-wrap did) in case it modifies its input in place.
            logits = model(transform(data.clone())[None])
            prob = softmax(logits.cpu().numpy(), axis=1)
            confidence = prob[0, 1]
            confidences.append(confidence)
            if confidence > detect_threshold:
                key = data.cpu().numpy()
                if normalize:
                    # Divide each row (channel, as plotted later) by its mean
                    # along axis 1; broadcasting replaces the np.repeat copy.
                    key = key / key.mean(axis=1, keepdims=True)
                detections.append({
                    'data': key,
                    'confidence': confidence,
                    'file_index': i,
                    'file_name': file_name,
                    'normalization': 'normalized' if normalize else 'raw',
                    'header': header,
                })
    num_pos = int(np.count_nonzero(np.asarray(confidences) > _COUNT_THRESHOLD))
    return detections, num_pos


all_detections = []
for file_no, spec in enumerate(_FILES, start=1):
    print(spec['message'])
    found, num_pos = _scan_file(
        f"{_FRB_ROOT}/{spec['obs']}/{spec['beam']}/decimated.fil",
        spec['start'], spec['stop'], spec['label'],
        detect_threshold=spec['detect_threshold'],
        normalize=spec['normalize'],
    )
    all_detections.extend(found)
    print(f"File {file_no} detections: {num_pos}")
# --- Combined figure of all detections, plus a text summary ------------------
print(f"\nTotal detections found: {len(all_detections)}")
if all_detections:
    # Sort detections by confidence (highest first)
    all_detections.sort(key=lambda x: x['confidence'], reverse=True)
    n_detections = len(all_detections)
    cols = 2
    # Keep the original 2x5 layout for up to 10 detections, but grow extra rows
    # instead of raising IndexError when there are more.
    rows = max(5, (n_detections + cols - 1) // cols)
    fig, axes = plt.subplots(rows, cols, figsize=(10, 12 * rows / 5), squeeze=False)
    axes_flat = axes.flatten()

    time_increment = 6.5e-5  # seconds per sample (hard-coded, not read from header)
    n_ticks = 5              # time-axis ticks per panel
    n_freq_ticks = 5         # frequency-axis ticks per panel

    for idx, detection in enumerate(all_detections):
        ax = axes_flat[idx]
        img = detection['data']
        # nanmedian: a channel normalised by a zero mean becomes NaN, which
        # would otherwise turn vmin into NaN.
        data_median = np.nanmedian(img)
        ax.imshow(img, aspect=6, cmap='hot', vmin=data_median)

        # Time axis: tick labels in seconds from the start of the window.
        n_samples = img.shape[1]
        tick_positions = np.linspace(0, n_samples - 1, n_ticks)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels([f"{t * time_increment:.2f}" for t in tick_positions], fontsize=12)
        # Label the lowest visible panel of each column (no panel below it),
        # so partial grids still get a time label.
        if idx + cols >= n_detections:
            ax.set_xlabel('Time (seconds)', fontsize=14)

        # Frequency axis from the header: channel k is at fch1 + k * foff MHz.
        header = detection['header']
        fch1, foff, nchans = header.fch1, header.foff, header.nchans
        freq_tick_positions = np.linspace(0, nchans - 1, n_freq_ticks)
        ax.set_yticks(freq_tick_positions)
        ax.set_yticklabels([f"{fch1 + k * foff:.1f}" for k in freq_tick_positions], fontsize=12)
        if idx % cols == 0:  # left column only
            ax.set_ylabel('Freq. (MHz)', fontsize=14)

        ax.tick_params(axis='both', which='major', size=3)

    # Hide empty subplots
    for ax in axes_flat[n_detections:]:
        ax.set_visible(False)

    plt.subplots_adjust(hspace=0.3, wspace=0.2)
    out_path = 'combined_frb_detections.pdf'
    plt.savefig(out_path, dpi=150, bbox_inches='tight', format='pdf')
    plt.show()
    print(f"Combined plot saved as '{out_path}'")

    print("\nDetection Summary:")
    for rank, detection in enumerate(all_detections, start=1):
        print(f"{rank}. {detection['file_name'][:50]}... - Confidence: {detection['confidence']:.4f}")
else:
    print("No detections found across all files.")