Sajin Payandath
commited on
Commit
·
6d378ef
1
Parent(s):
75ced52
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Import Gradio and other libraries
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
import torchvision
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.optim as optim
|
8 |
+
from torchvision import transforms
|
9 |
+
from PIL import Image
|
10 |
+
import pandas as pd
|
11 |
+
import joblib
|
12 |
+
|
13 |
+
from sklearn.preprocessing import MinMaxScaler,LabelEncoder
|
14 |
+
|
15 |
+
path = os.getcwd()
|
16 |
+
n_tab_features = 8
|
17 |
+
num_classes = 11
|
18 |
+
n_pain_cls = 4
|
19 |
+
n_hidden = 128
|
20 |
+
|
21 |
+
scaler = MinMaxScaler(feature_range=(0,1))
|
22 |
+
encoder = LabelEncoder()
|
23 |
+
|
24 |
+
fun_tran= lambda x: [int(y) if y!=1 else 0 for y in np.exp( 5*np.log( [k if k>0 else 1 for k in x.flatten()] ) ) ]
|
25 |
+
|
26 |
+
# Define the model class
|
27 |
+
class DosagePredictionModel(nn.Module):
|
28 |
+
def __init__(self, n_tab_features, n_pain_cls):
|
29 |
+
super(DosagePredictionModel, self).__init__()
|
30 |
+
self.fc1 = nn.Linear(n_tab_features, n_hidden) # n_tab_features input features
|
31 |
+
self.fc2 = nn.Linear(n_hidden, n_hidden)
|
32 |
+
self.fc3 = nn.Linear(n_hidden, n_hidden//2)
|
33 |
+
self.fc4 = nn.Linear(n_hidden//2, n_hidden//2)
|
34 |
+
self.bn5 = nn.BatchNorm1d(n_hidden//2)
|
35 |
+
self.fc6 = nn.Linear(n_hidden//2, num_classes) # Output for Medicine classification
|
36 |
+
self.fc7 = nn.Linear(n_hidden//2, n_pain_cls) # Output for Pain classification
|
37 |
+
self.fc8 = nn.Linear(n_hidden//2, 1) # Output for regression
|
38 |
+
|
39 |
+
|
40 |
+
def forward(self, x, regression_targets=None,classification_targets=None, pain_cls_tgt=None ):
|
41 |
+
x = torch.relu(self.fc1(x))
|
42 |
+
x = torch.relu(self.fc2(x))
|
43 |
+
x = torch.relu(self.fc3(x))
|
44 |
+
x = torch.relu(self.fc4(x))
|
45 |
+
x = self.bn5(x)
|
46 |
+
|
47 |
+
regression_output = self.fc8(x) # Regression output
|
48 |
+
regression_y = regression_output
|
49 |
+
|
50 |
+
# Medicine Classification Layer
|
51 |
+
output_classification = self.fc6(x)
|
52 |
+
output_classification = torch.softmax(output_classification,dim=1)
|
53 |
+
pred_cls_y = torch.argmax(output_classification, dim = 1)
|
54 |
+
|
55 |
+
# Pain Classification Layer
|
56 |
+
pain_output_cls = self.fc7(x)
|
57 |
+
pain_output_cls = torch.softmax(pain_output_cls,dim=1)
|
58 |
+
pred_pain_cls_y = torch.argmax(pain_output_cls, dim = 1)
|
59 |
+
|
60 |
+
if (classification_targets is None) or (regression_targets is None) or (pain_cls_tgt is None):
|
61 |
+
print('Inference Mode')
|
62 |
+
loss, regression_loss, med_cls_loss, pain_cls_loss = None, None, None, None
|
63 |
+
else:
|
64 |
+
loss_fn = nn.SmoothL1Loss() # mean square error
|
65 |
+
regression_loss = loss_fn(regression_output, regression_targets.view(-1,1))
|
66 |
+
#Classification
|
67 |
+
med_cls_loss = nn.CrossEntropyLoss()(output_classification, classification_targets )
|
68 |
+
#pain_class
|
69 |
+
pain_cls_loss = nn.CrossEntropyLoss()(pain_output_cls, pain_cls_tgt )
|
70 |
+
|
71 |
+
weight_dosage, weight_med_cls = 0.2, 0.4
|
72 |
+
loss = weight_dosage * regression_loss + weight_med_cls * med_cls_loss + (1-(weight_dosage+weight_med_cls))*pain_cls_loss
|
73 |
+
return loss,regression_loss,med_cls_loss,pain_cls_loss, regression_y, pred_cls_y, pred_pain_cls_y
|
74 |
+
|
75 |
+
|
76 |
+
# Create an instance of the model
|
77 |
+
model = DosagePredictionModel(n_tab_features, n_pain_cls)
|
78 |
+
|
79 |
+
# Load the model weights from a file (assuming it is saved as "model.pt")
|
80 |
+
model.load_state_dict(torch.load("med_dos_pain_1.pt"))
|
81 |
+
|
82 |
+
# Set the model to evaluation mode
|
83 |
+
model.eval()
|
84 |
+
|
85 |
+
# Define a list of possible medicines
|
86 |
+
medicines = ["Aspirin", "Lisinopril", "Metoprolol", "Hydrochlorothiazide"]
|
87 |
+
dense_features = ['Weight_in_Kgs','Heart_rate','pulse_rate','Systolic_BP','Diastolic_BP','BIS_Value','SPO2']
|
88 |
+
sparse_features = ['Pain_Position']
|
89 |
+
|
90 |
+
# Define a function to get the medicine name from the model output
|
91 |
+
def get_medicine_name(pred_medicine_class):
|
92 |
+
med_encoder = joblib.load(path+'/cat_encoder.joblib')
|
93 |
+
pred_medicine_name = med_encoder.inverse_transform(pred_medicine_class)
|
94 |
+
# Return the corresponding medicine name
|
95 |
+
return pred_medicine_name
|
96 |
+
|
97 |
+
# Define a function to get the model input from the user input
|
98 |
+
def get_model_input(Weight_in_Kgs, Heart_rate, pulse_rate, Systolic_BP, Diastolic_BP, BIS_Value, SPO2, Pain_Position):
|
99 |
+
values = [[Weight_in_Kgs, Heart_rate, pulse_rate, Systolic_BP, Diastolic_BP, BIS_Value, SPO2]]
|
100 |
+
cat_values = [[Pain_Position]]
|
101 |
+
all_values = [[Weight_in_Kgs, Heart_rate, pulse_rate, Systolic_BP, Diastolic_BP, BIS_Value, SPO2,Pain_Position]]
|
102 |
+
temp = pd.DataFrame(all_values, columns= dense_features + sparse_features)
|
103 |
+
# print(temp.dtypes)
|
104 |
+
# Normalize the dense feature values to be between 0 and 1
|
105 |
+
scaler = joblib.load(path+'/scaler.joblib')
|
106 |
+
temp[dense_features] = scaler.transform( temp[dense_features] )
|
107 |
+
print(temp)
|
108 |
+
|
109 |
+
# Encode the categorical features 0, 1 etc.
|
110 |
+
painpos_encoder = joblib.load(path+'/cat_pain_level_encoder.joblib')
|
111 |
+
temp[sparse_features] = painpos_encoder.transform( cat_values[0] )
|
112 |
+
|
113 |
+
# Create a tensor from the normalized values
|
114 |
+
input = torch.tensor(temp.to_numpy()).float()
|
115 |
+
print(input)
|
116 |
+
# Return the input tensor
|
117 |
+
return input
|
118 |
+
|
119 |
+
# Define a function to get the prediction from the user input
|
120 |
+
def predict(Weight_in_Kgs, Heart_rate, pulse_rate, Systolic_BP, Diastolic_BP, BIS_Value, SPO2, Pain_Position):
|
121 |
+
# Get the model input from the user input
|
122 |
+
X = get_model_input(Weight_in_Kgs, Heart_rate, pulse_rate, Systolic_BP, Diastolic_BP, BIS_Value, SPO2, Pain_Position)
|
123 |
+
print(X.dtype)
|
124 |
+
# Get the model output from the model input
|
125 |
+
_, _, _, _, pred_dosage, pred_medicine_class, pred_pain_cls = model(X)
|
126 |
+
# Get the medicine name from the model output
|
127 |
+
medicine = get_medicine_name(pred_medicine_class)
|
128 |
+
# Return the predicted medicine name as a string
|
129 |
+
return "The predicted medicine for you is: " + medicine
|
130 |
+
|
131 |
+
Weight_in_Kgs, Heart_rate, pulse_rate, Systolic_BP = 75.3, 85.7, 84, 141
|
132 |
+
Diastolic_BP, BIS_Value, SPO2, Pain_Position = 92, 80, 90, 'Lower Back'
|
133 |
+
predict(Weight_in_Kgs, Heart_rate, pulse_rate, Systolic_BP, Diastolic_BP, BIS_Value, SPO2, Pain_Position)
|
134 |
+
# Create a Gradio interface with two number inputs and a text output
|
135 |
+
interface = gr.Interface(
|
136 |
+
fn=predict,
|
137 |
+
inputs=[gr.inputs.Slider(minimum=40,maximum=120,step=1,label="Weight in Kgs"),
|
138 |
+
gr.inputs.Slider(minimum=40,maximum=105,step=1,label="Heart rate"),
|
139 |
+
gr.inputs.Slider(minimum=60,maximum=100,step=1,label="Pulse rate"),
|
140 |
+
gr.inputs.Slider(minimum=115,maximum=142,step=1,label="Systolic BP"),
|
141 |
+
gr.inputs.Slider(minimum=70,maximum=93,step=1,label="Diastolic BP"),
|
142 |
+
gr.inputs.Slider(minimum=30,maximum=100,step=1,label="BIS Value"),
|
143 |
+
gr.inputs.Slider(minimum=71,maximum=100,step=1,label="SPO2"),
|
144 |
+
gr.inputs.Dropdown(['Lower Back', 'Shoulder ', 'Abdomine', 'Muscle', 'Neck',
|
145 |
+
'Fybromialgia', 'Knee Joint', 'Joint pain', 'Shoulder', 'No Pain'], label="Pain Position", default="Lower Back") ],
|
146 |
+
outputs=gr.outputs.Textbox(label="Medicine") )
|
147 |
+
|
148 |
+
# Launch the interface on Hugging Face Spaces (assuming you have an account)
|
149 |
+
interface.launch()
|
150 |
+
|