# Spaces: Runtime error
import numpy as np | |
import torch | |
from torch import nn | |
from model_def import NeuralNetwork | |
# Class index -> rhythm label. Presumably the index matches the score column
# produced by the models (TODO confirm against the training code).
labels_map = dict(enumerate([
    'sinus_rhythm',
    'atrial_fibrillation',
    'av_block',
    'bradycardia',
    'sinus_arrhythmia',
    'sinus_rhythm-sinus_arrhythmia',
    'sinus_rhythm-av_block',
]))
# Directory holding one ResNet checkpoint per lead.
PATH = 'model/seven-diseases/'

# PATH ends with "/", so strip it before joining to avoid "//" in the paths.
_MODEL_DIR = PATH.rstrip('/')
lead_1 = f"{_MODEL_DIR}/ResNet-lead-0-BEST.pth"
lead_2 = f"{_MODEL_DIR}/ResNet-lead-1-BEST.pth"
lead_3 = f"{_MODEL_DIR}/ResNet-lead-2-BEST.pth"
# Earlier checkpoints lived at 'model/ResNet-lead-{0,1,2}.pth'.


def _load_lead_model(checkpoint_path):
    """Build a NeuralNetwork, load CPU-mapped weights from *checkpoint_path*, set eval mode.

    Args:
        checkpoint_path: Path to a saved ``state_dict`` for ``NeuralNetwork``.

    Returns:
        The loaded model, in evaluation mode.
    """
    model = NeuralNetwork()
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # Without eval(), BatchNorm/Dropout layers (if the model has any) stay in
    # training mode, making predictions batch-dependent and non-deterministic.
    model.eval()
    return model


lead_1_model = _load_lead_model(lead_1)
lead_2_model = _load_lead_model(lead_2)
lead_3_model = _load_lead_model(lead_3)
def helper(sig, model):
    """Run *model* on a batch of single-channel signals and return exponentiated scores.

    Args:
        sig: Array-like of shape ``(batch, length)``. A 1-D array is treated as
            a single signal, i.e. a batch of one.
        model: Callable (e.g. a ``torch.nn.Module``) that accepts a float32
            tensor of shape ``(batch, 1, length)``.

    Returns:
        ``numpy.ndarray`` holding ``exp(model(x))``. The exponential suggests
        the model emits log-probabilities (e.g. ``log_softmax``); TODO confirm
        against ``model_def``.
    """
    # Cast to float32, assuming the model keeps PyTorch's default float32
    # parameters. Float64 input (NumPy's default) would otherwise raise a
    # dtype-mismatch error inside the model. ascontiguousarray also accepts
    # lists and negatively strided views, which torch.from_numpy rejects.
    signals = np.atleast_2d(np.ascontiguousarray(sig, dtype=np.float32))
    # Insert the channel axis: (batch, length) -> (batch, 1, length).
    inpt = signals[:, np.newaxis, :]
    with torch.no_grad():
        res = model(torch.from_numpy(inpt))
    return torch.exp(res).numpy()
def predict_disease(lead1, lead2, lead3):
    """Score three leads with their dedicated models and average the results.

    Each lead is passed through ``helper`` with its own module-level model
    (``lead_1_model``, ``lead_2_model``, ``lead_3_model``).

    Returns:
        Tuple ``(p_avg, p1, p2, p3)``: the element-wise mean of the three
        per-lead score arrays, followed by the per-lead arrays themselves.
        Presumably ``p_avg.argmax(axis=1)`` indexes ``labels_map`` (TODO
        confirm).
    """
    per_lead = [
        helper(signal, net)
        for signal, net in (
            (lead1, lead_1_model),
            (lead2, lead_2_model),
            (lead3, lead_3_model),
        )
    ]
    first, second, third = per_lead
    mean_scores = (first + second + third) / 3
    return mean_scores, first, second, third