# DACVV / app.py
# Author: vishal323
# Last commit: "Update app.py" (c0ad3c2)
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import gradio as gr
from math import sqrt
from datasets import load_dataset
# Load the heart-disease CSV. ``load_dataset`` returns a DatasetDict; a plain
# ``data_files`` string is exposed as the "train" split. Convert that split to
# a pandas DataFrame, because every filter and plot below uses the pandas API.
df = load_dataset("csv", data_files="vishal323/heart.csv")["train"].to_pandas()
df.info()

# Human-readable names for the coded categorical columns.
# NOTE(review): assumes cp is coded 0-3 and restecg 0-2 in the order below
# (the order the original labelling intended) — confirm against the dataset card.
CP_LABELS = {
    0: 'Typical Angina',
    1: 'Atypical Angina',
    2: 'non-anginal',
    3: 'asymptomatic',
}
ECG_LABELS = {
    0: 'normal',
    1: 'having ST-T wave abnormality',
    2: 'showing probable or definite left ventricular hypertrophy by Estes',
}


def _labelled_counts(codes, labels, count_col):
    """Count occurrences of each code in *codes*, keyed by its label.

    Labels are attached by code value, not by position in the frequency-sorted
    result. That way each bar is named after the category it actually counts.
    Codes missing from *labels* are shown as their string form.

    Returns a DataFrame with columns ``index`` (label) and *count_col*
    (frequency), most frequent first. This is the layout ``outbreak`` reads.
    """
    named = codes.map(labels).fillna(codes.astype(str))
    # rename_axis + reset_index(name=...) pins the column names, so they do not
    # depend on the pandas version's default value_counts/reset_index naming.
    return named.value_counts().rename_axis('index').reset_index(name=count_col)


cp_data = _labelled_counts(df['cp'], CP_LABELS, 'cp')
ecg_data = _labelled_counts(df['restecg'], ECG_LABELS, 'restecg')
def outbreak(feature):
    """Plot the distribution of one feature over the whole dataset.

    Backs the "Visualisation" tab.

    Args:
        feature: One of "Age", "Sex", "Chest Pain", "ECG" or "Blood Pressure".

    Returns:
        A matplotlib Figure. It is left empty when *feature* is not one of the
        names above (e.g. nothing selected yet).
    """
    plt.rcParams.update({'font.size': 10})
    plt.rc('xtick', labelsize=5)
    fig, ax = plt.subplots()
    if feature == "Age":
        sns.countplot(x='age', data=df, ax=ax)
        ax.set_title("Age of Patients")
        # Set labels after plotting: seaborn labels the axes from column names.
        ax.set_xlabel("Age")
    elif feature == "Sex":
        sns.countplot(x='sex', data=df, ax=ax)
        ax.set_title("Sex of Patients,0=Female and 1=Male")
    elif feature == "Chest Pain":
        sns.barplot(x=cp_data['index'], y=cp_data['cp'], ax=ax)
        ax.set_title("Chest Pain of Patients")
    elif feature == "ECG":
        sns.barplot(x=ecg_data['index'], y=ecg_data['restecg'], ax=ax)
        ax.set_title("ECG data of Patients")
    elif feature == "Blood Pressure":
        # histplot replaces the deprecated distplot. stat="density" keeps
        # distplot's normalised y-axis, so the KDE curve lines up with the bars.
        sns.histplot(df['trestbps'], kde=True, color='magenta', stat='density', ax=ax)
        ax.set_title("Resting Blood Pressure (mmHg)")
        ax.set_xlabel("Resting Blood Pressure (mmHg)")
    # Release pyplot's reference so figures do not accumulate across requests.
    # The returned Figure object can still be rendered.
    plt.close(fig)
    return fig
# Maps a target choice to (column to count, plot title, x-axis label).
_OP_PLOTS = {
    "Age": ('age', "Count of age of diseased people", "Age"),
    "Sex": ('sex', "Count of sex of diseased people", "Sex"),
    "Chest Pain": ('cp', "Count of diseased people with chest pain", "Chest Pain"),
    "ECG": ('restecg', "Count of diseased people by resting ECG result", "ECG"),
    "Blood Pressure": ('trestbps', "Count of diseased people with high BP", "BP"),
}


def _diseased_subset(frame, sex, cp, age, bp, ch):
    """Return the rows of *frame* with ``target == 1`` that match every filter.

    ``sex`` and ``cp`` must match exactly. They are coerced with ``int`` so
    that string values coming from the UI compare correctly. ``age``, ``bp``
    and ``ch`` are inclusive lower bounds on ``age``, ``trestbps`` and ``chol``.
    """
    # Each comparison is parenthesised because ``&`` binds tighter than ``==``.
    mask = (
        (frame['target'] == 1)
        & (frame['sex'] == int(sex))
        & (frame['cp'] == int(cp))
        & (frame['age'] >= age)
        & (frame['trestbps'] >= bp)
        & (frame['chol'] >= ch)
    )
    return frame[mask]


def op(target, sex, cp, age, bp, ch):
    """Plot a count of diseased patients who match the given filters.

    Backs the "Real Time Analysis" tab.

    Args:
        target: Feature to count: "Age", "Sex", "Chest Pain", "ECG" or
            "Blood Pressure".
        sex: Required ``sex`` code.
        cp: Required chest-pain (``cp``) code.
        age: Minimum age.
        bp: Minimum resting blood pressure (``trestbps``).
        ch: Minimum cholesterol (``chol``).

    Returns:
        A matplotlib Figure. It is left empty when *target* is unknown, *sex*
        or *cp* is unset, or no patient matches the filters.
    """
    plt.rcParams.update({'font.size': 10})
    plt.rc('xtick', labelsize=5)
    fig, ax = plt.subplots()
    spec = _OP_PLOTS.get(target)
    if spec is not None and sex is not None and cp is not None:
        data = _diseased_subset(df, sex, cp, age, bp, ch)
        if not data.empty:
            column, title, xlabel = spec
            sns.countplot(x=column, data=data, ax=ax)
            ax.set_title(title)
            # Set after plotting so seaborn's column-name label is replaced.
            ax.set_xlabel(xlabel)
    # Release pyplot's reference so figures do not accumulate across requests.
    plt.close(fig)
    return fig
# Unpickled models, keyed by file path, so each file is read only once.
MODEL_CACHE = {}
_RF_MODEL_PATH = 'DACVV/randomforest.pkl'
_SCALER_MODEL_PATH = 'DACVV/scaling.pkl'


def _load_model(filename):
    """Return the model pickled at *filename*, loading it on first use."""
    model = MODEL_CACHE.get(filename)
    if model is None:
        # NOTE: pickle runs arbitrary code on load. Only use this for model
        # files shipped with the app, never for user-supplied paths.
        with open(filename, 'rb') as fh:
            model = pickle.load(fh)
        MODEL_CACHE[filename] = model
    return model


def prd(model, age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal):
    """Predict heart disease for one patient record with the chosen model.

    Args:
        model: "Random Forest" selects ``randomforest.pkl``. Any other value
            selects ``scaling.pkl``.
        age ... thal: The 13 feature values, in the column order the models
            were presumably trained on (TODO confirm against the training code).

    Returns:
        A user-facing message: a warning when the model predicts class 1,
        otherwise a "healthy" message.
    """
    filename = _RF_MODEL_PATH if model == "Random Forest" else _SCALER_MODEL_PATH
    # A single row is enough for predict(). dtype=float also coerces numeric
    # strings that UI widgets may deliver.
    features = np.array(
        [[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal]],
        dtype=float,
    )
    result = _load_model(filename).predict(features)[0]
    return "😔, You may have a Heart Disease" if result == 1 else "😁, You are Healthy!!!"
# "Visualisation" tab: whole-dataset overview plot for the chosen feature.
feature_choices = ["Age", "Sex", "Chest Pain", "ECG", "Blood Pressure"]
inputs = gr.Dropdown(choices=feature_choices, label="Input Feature")
outputs = gr.Plot()
visualisation = gr.Interface(outbreak, inputs, outputs)
# "Real Time Analysis" tab: filter diseased patients and count one feature.
# The input order must match op(target, sex, cp, age, bp, ch).
vis = gr.Interface(
    inputs=[
        gr.Radio(["Age", "Sex", "Chest Pain", "ECG", "Blood Pressure"], label="Target Feature"),
        gr.Radio([1, 0], label="Sex"),
        gr.Radio([0, 1, 2, 3], label="Chest Pain"),
        gr.Slider(25, 80, value=50, step=1, label="Age"),
        gr.Slider(94, 200, value=150, step=1, label="Blood Pressure"),
        gr.Slider(126, 564, value=130, step=1, label="Cholesterol"),
    ],
    fn=op,
    outputs=gr.Plot(),
    examples=[
        ["Age", 1, 2, 50, 100, 222],
        ["Sex", 0, 1, 30, 150, 322],
        ["Chest Pain", 1, 0, 40, 120, 422],
        ["ECG", 1, 3, 70, 98, 522],
        ["Blood Pressure", 0, 1, 28, 170, 262],
    ],
)
# "Predictions" tab: run one of the pickled models on a single patient record.
# The input order must match prd's parameter order.
pred = gr.Interface(
    inputs=[
        gr.Radio(["Random Forest", "Scaler"], label="Model"),
        gr.Number(label="Age"),
        gr.Radio([0, 1], label="Sex"),
        gr.Radio([0, 1, 2, 3], label="Chest Pain"),
        gr.Slider(94, 200, value=150, step=1, label="Blood Pressure"),
        gr.Slider(126, 564, value=130, step=1, label="Cholesterol"),
        gr.Radio([0, 1], label="FBS"),
        gr.Radio([0, 1, 2], label="RestECG"),
        # Slider defaults must lie within [minimum, maximum].
        gr.Slider(71, 202, value=150, step=1, label="Thalach"),
        gr.Radio([0, 1], label="exang"),
        gr.Slider(0, 6.2, value=3, label="OldPeak"),
        gr.Radio([0, 1, 2], label="Slope"),
        gr.Slider(0, 4, value=3, step=1, label="CA"),
        gr.Slider(0, 3, value=2, step=1, label="Thal"),
    ],
    fn=prd,
    outputs="text",
    examples=[
        ["Random Forest", 52, 1, 0, 125, 212, 0, 1, 168, 0, 1.0, 2, 2, 3],
        ["Scaler", 62, 0, 0, 138, 294, 1, 1, 106, 0, 1.9, 1, 3, 2],
        ["Random Forest", 44, 0, 2, 108, 141, 0, 1, 175, 0, 0.6, 1, 0, 2],
        ["Scaler", 68, 0, 2, 120, 211, 0, 0, 115, 0, 1.5, 1, 0, 2],
    ],
)
# Assemble the three tabs (titles listed in the same order) and start the server.
tab_titles = ["Visualisation", "Real Time Analysis", "Predictions"]
interface = gr.TabbedInterface([visualisation, vis, pred], tab_titles)
interface.launch(inline=False)