# app.py - page header captured with this source (Hugging Face):
# author: runaksh | last commit: "Update app.py" (93d1cfd) | size: 4.19 kB
from fastapi import FastAPI, Request, Response
import gradio
import joblib
from xgboost import XGBClassifier
import pandas as pd
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
import prometheus_client as prom
# ASGI application; the Gradio UI is mounted onto it at the bottom of the file.
app = FastAPI()

# Hugging Face Hub coordinates of the model repository.
# NOTE(review): repo_path is not referenced by the code visible here.
username = "runaksh"
repo_name = "Patientsurvival-model"
repo_path = f"{username}/{repo_name}"

# Trained XGBoost classifier, deserialized from a local file.
# NOTE(review): joblib.load unpickles its input - only ever load trusted files.
xgb_model_loaded = joblib.load("xgboost-model.pkl")
import pandas as pd
# Held-out evaluation rows; update_metrics() scores a random sample of them.
test_data = pd.read_csv("test_data.csv")

# Prometheus gauges refreshed on every scrape of /metrics.
# NOTE(review): the "sentiment_" prefix looks copied from a sentiment-analysis
# template; it is kept because renaming would change the exported series names.
f1_metric = prom.Gauge(
    'sentiment_f1_score',
    'F1 score for random 100 test samples',
)
precision_metric = prom.Gauge(
    'sentiment_precision_score',
    'Precision score for random 100 test samples',
)
recall_metric = prom.Gauge(
    'sentiment_recall_score',
    'Recall score for random 100 test samples',
)
# Helper: map a yes/no UI input to the integer flag (0/1) used as a model feature.
def bol_to_int(bol):
    """Convert a yes/no input into the integer flag 0 or 1.

    Non-string values keep equality semantics: anything equal to ``True``
    (``True``, ``1``, ``1.0``) maps to 1, everything else to 0.

    Strings are handled explicitly because the Gradio Radio widgets in this
    app offer the choices ``"0"`` and ``"1"`` as text. A plain
    ``bol == True`` test is never true for a string, so every such feature
    would otherwise silently become 0.

    Args:
        bol: Value to convert. The strings ``"1"``, ``"true"`` and ``"yes"``
            (case-insensitive, surrounding whitespace ignored) map to 1; any
            other string maps to 0.

    Returns:
        1 for a set flag, otherwise 0 (including for ``None``).
    """
    if isinstance(bol, str):
        return 1 if bol.strip().lower() in {"1", "true", "yes"} else 0
    # Equality rather than truthiness on purpose: preserves the original
    # mapping where values such as 2 or 0.5 give 0.
    return 1 if bol == True else 0  # noqa: E712
# Function for prediction
def predict_death_event(feature1, feature2, feature3, feature4, feature5, feature6,
                        feature7, feature8, feature9, feature10, feature11, feature12):
    """Run the XGBoost model on a single patient record and return its prediction.

    The twelve positional inputs map, in order, to the feature columns
    age, anaemia, creatinine_phosphokinase, diabetes, ejection_fraction,
    high_blood_pressure, platelets, serum_creatinine, serum_sodium, sex,
    smoking and time. The yes/no features (anaemia, diabetes,
    high_blood_pressure, sex, smoking) are passed through ``bol_to_int``.

    Returns:
        The first (only) element of ``xgb_model_loaded.predict`` for the
        one-row DataFrame.
    """
    column_names = (
        'age',
        'anaemia',
        'creatinine_phosphokinase',
        'diabetes',
        'ejection_fraction',
        'high_blood_pressure',
        'platelets',
        'serum_creatinine',
        'serum_sodium',
        'sex',
        'smoking',
        'time',
    )
    flag_columns = frozenset(
        ('anaemia', 'diabetes', 'high_blood_pressure', 'sex', 'smoking')
    )
    raw_values = (feature1, feature2, feature3, feature4, feature5, feature6,
                  feature7, feature8, feature9, feature10, feature11, feature12)

    # Dict insertion order fixes the DataFrame column order; presumably this
    # must match the training column order, so keep column_names stable.
    single_row = {
        name: [bol_to_int(value) if name in flag_columns else value]
        for name, value in zip(column_names, raw_values)
    }
    frame = pd.DataFrame(single_row)
    predictions = xgb_model_loaded.predict(frame)
    return predictions[0]
# Function for updating metrics
def update_metrics(sample_size=100, target_column="DEATH_EVENT"):
    """Score the model on a random sample of ``test_data`` and publish the results.

    Draws up to ``sample_size`` rows from ``test_data`` and predicts them with
    ``xgb_model_loaded``. It then sets ``f1_metric``, ``precision_metric`` and
    ``recall_metric`` to the scores against ``target_column``, rounded to 3
    decimals. If there are no rows to score, the gauges keep their previous
    values.

    Args:
        sample_size: Maximum number of rows to score; capped at
            ``len(test_data)`` so small test sets do not make ``sample`` raise.
        target_column: Name of the ground-truth label column in ``test_data``.
            NOTE(review): ``"DEATH_EVENT"`` is the label column of the public
            heart-failure clinical-records dataset - confirm that
            test_data.csv uses the same name and 0/1 labels.

    Raises:
        KeyError: if ``test_data`` lacks ``target_column`` or a feature column.
    """
    # Same columns, in the same order, as the frame built in predict_death_event.
    feature_columns = [
        'age', 'anaemia', 'creatinine_phosphokinase', 'diabetes',
        'ejection_fraction', 'high_blood_pressure', 'platelets',
        'serum_creatinine', 'serum_sodium', 'sex', 'smoking', 'time',
    ]

    n_rows = min(sample_size, len(test_data))
    if n_rows <= 0:
        # Nothing to score; the metric functions would raise on empty input.
        return
    sample = test_data.sample(n=n_rows)
    y_true = sample[target_column]
    y_pred = xgb_model_loaded.predict(sample[feature_columns])

    # zero_division=0: a sample with no predicted/actual positives reports 0
    # instead of emitting warnings. Builtin round() works whether scikit-learn
    # returns a numpy float or a plain Python float (which has no .round()).
    f1_metric.set(round(f1_score(y_true, y_pred, zero_division=0), 3))
    precision_metric.set(round(precision_score(y_true, y_pred, zero_division=0), 3))
    recall_metric.set(round(recall_score(y_true, y_pred, zero_division=0), 3))
@app.get("/metrics")
async def get_metrics():
    """Refresh the evaluation gauges, then return all Prometheus metrics as plain text."""
    update_metrics()
    exposition = prom.generate_latest()
    return Response(content=exposition, media_type="text/plain")
# Gradio interface providing the web UI, mounted at "/" on the FastAPI app.
title = "Patient Survival Prediction"
description = "Predict survival of patient with heart failure, given their clinical record"


def _binary_radio(label):
    """Return a two-choice radio widget offering the text options "0" and "1"."""
    # NOTE(review): with Gradio's default type="value" the selected option is
    # passed on as the string "0"/"1", so bol_to_int must accept strings.
    return gradio.components.Radio(["0", "1"], label=label)


# Widget order must match predict_death_event's positional parameters
# (feature1 ... feature12).
feature_inputs = [
    gradio.components.Slider(30, 100, step=1, label='age'),
    _binary_radio('anaemia'),
    gradio.components.Slider(1, 10000, step=1, label='creatinine_phosphokinase'),
    _binary_radio('diabetes'),
    gradio.components.Slider(1, 100, step=1, label='ejection_fraction'),
    _binary_radio('high_blood_pressure'),
    gradio.components.Number(label='platelets'),
    gradio.components.Slider(0.1, 10.0, step=0.1, label='serum_creatinine'),
    gradio.components.Slider(100, 150, step=1, label='serum_sodium'),
    _binary_radio('sex'),
    _binary_radio('smoking'),
    gradio.components.Slider(1, 300, step=1, label='time'),
]

iface = gradio.Interface(
    fn=predict_death_event,
    inputs=feature_inputs,
    outputs=[gradio.components.Textbox(label='DeathEvent')],
    title=title,
    description=description,
)
app = gradio.mount_gradio_app(app, iface, path="/")
#iface.launch(server_name = "0.0.0.0", server_port = 8001) # Ref. for parameters: https://www.gradio.app/docs/interface
if __name__ == "__main__":
    # Debug-only entry point: serve the combined FastAPI + Gradio app with uvicorn.
    import uvicorn

    debug_host, debug_port = "0.0.0.0", 8001
    uvicorn.run(app, host=debug_host, port=debug_port)