ecg-classification / postprocessing.py
nabeelraza's picture
Initial Commit
ff522d1
raw
history blame
1.11 kB
import numpy as np
import torch
from torch import nn
from model_def import NeuralNetwork
# Maps the integer class index to its human-readable rhythm label.
# The position of each name in the tuple is its class index (0-4).
labels_map = dict(
    enumerate(
        (
            "atrial fibrillation",
            "sinus arrhythmia",
            "bradycardia",
            "1st degree av block",
            "sinus rhythm",
        )
    )
)
# Directory that holds the per-lead checkpoints (ResNet-lead-<n>.pth).
PATH = 'model/'


def _load_lead_model(lead_index):
    """Build a NeuralNetwork and load the checkpoint for one lead.

    Args:
        lead_index: Zero-based index used in the checkpoint file name
            ``ResNet-lead-<lead_index>.pth`` under ``PATH``.

    Returns:
        The model with its weights loaded onto the CPU and switched to
        evaluation mode.
    """
    model = NeuralNetwork()
    state_dict = torch.load(
        f"{PATH}/ResNet-lead-{lead_index}.pth",
        map_location=torch.device('cpu'),
    )
    model.load_state_dict(state_dict)
    # Inference only: put layers such as dropout / batch-norm (if the
    # architecture has any) into evaluation mode so predictions are
    # deterministic and use the trained running statistics.
    model.eval()
    return model


# One model per ECG lead; the variable names are 1-based while the
# checkpoint files are 0-based.
lead_1_model = _load_lead_model(0)
lead_2_model = _load_lead_model(1)
lead_3_model = _load_lead_model(2)
def helper(sig, model):
    """Run ``model`` on a batch of single-lead signals and return its scores.

    Args:
        sig: NumPy array indexed as ``sig[:, np.newaxis, :]``, i.e. a new
            axis is inserted at position 1 before the data is fed to the
            model (assumes a (batch, samples) layout -- TODO confirm
            against the caller).
        model: Callable applied to the resulting tensor; normally one of
            the per-lead ``NeuralNetwork`` instances.

    Returns:
        NumPy array holding ``exp(model(input))``. Presumably the model
        emits log-probabilities, so these are class probabilities --
        verify against ``model_def``.
    """
    batch = torch.from_numpy(sig[:, np.newaxis, :])

    # torch.from_numpy keeps the NumPy dtype, so a default float64 signal
    # would not match float32 weights. Align floating-point inputs with the
    # model's parameter dtype; inputs that already match are left untouched.
    parameters = getattr(model, "parameters", None)
    reference = next(iter(parameters()), None) if callable(parameters) else None
    if (
        reference is not None
        and batch.is_floating_point()
        and reference.is_floating_point()
        and batch.dtype != reference.dtype
    ):
        batch = batch.to(reference.dtype)

    # No gradients are needed for inference.
    with torch.no_grad():
        scores = torch.exp(model(batch))
    return scores.numpy()
def make_predictions_indi(lead1, lead2, lead3):
    """Average the per-lead model scores for three ECG leads.

    Each lead's signal is scored by its own model through ``helper``
    (lead1 -> lead_1_model, lead2 -> lead_2_model, lead3 -> lead_3_model)
    and the three score arrays are combined with an unweighted mean.

    Args:
        lead1: Signal batch passed to ``helper`` with ``lead_1_model``.
        lead2: Signal batch passed to ``helper`` with ``lead_2_model``.
        lead3: Signal batch passed to ``helper`` with ``lead_3_model``.

    Returns:
        The element-wise mean of the three score arrays.
    """
    signal_model_pairs = (
        (lead1, lead_1_model),
        (lead2, lead_2_model),
        (lead3, lead_3_model),
    )
    first, second, third = [
        helper(signal, model) for signal, model in signal_model_pairs
    ]
    return (first + second + third) / 3