Spaces:
Sleeping
Sleeping
File size: 5,342 Bytes
6880cdd 80f4283 f113093 6880cdd 6469332 6880cdd 6469332 6880cdd 6469332 6880cdd 1650e3d 6880cdd 98d2076 2528937 6880cdd 2528937 6880cdd 2528937 6880cdd 2528937 6469332 2528937 6880cdd 6469332 6880cdd 6469332 2528937 6880cdd 2528937 6880cdd 2528937 6880cdd 00c0fa0 2528937 d22446e 2528937 6880cdd 6469332 6880cdd 2528937 c210da8 6d7abe9 2528937 6880cdd 6dde7e1 3217373 2528937 6880cdd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import pandas as pd
import streamlit as st
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_percentage_error
import warnings
warnings.filterwarnings("ignore")

# Read the input files: OWID monkeypox case data (only the columns used by
# this app) and the population table.
data = pd.read_csv('owid-monkeypox-data.csv')
data = data[['location','iso_code','date','new_cases','total_cases','new_deaths','total_deaths']]
pop = pd.read_csv('API_SP.POP.TOTL_DS2_en_csv_v2_4578059.csv')

# Preprocess the data. The lookups are built with comprehensions / zip so that
# no loop variable is left behind at module scope, where a function in this
# file could silently pick it up instead of raising NameError.

# One case table per ISO code, each re-indexed from 0 so that row k is the
# k-th reported day for that location.
all_location = {
    iso_code: data[data['iso_code'] == iso_code].reset_index(drop=True)
    for iso_code in data['iso_code'].unique()
}

# 2021 population keyed by country code. tolist() yields plain Python scalars.
pop_dict = dict(zip(pop['Country Code'].tolist(), pop['2021'].tolist()))
# Hard-coded populations for codes that are presumably missing from the
# population table (OWID_WRL presumably being OWID's world aggregate) — confirm.
pop_dict.update({'GLP': 400000, 'MTQ': 376480, 'OWID_WRL': 7836630792})

# Location name -> array of the ISO codes that appear for it.
code = dict(data.groupby('location')['iso_code'].unique())
# SIR model differential equations.
def deriv(x, t, beta, gamma):
    """Return the SIR right-hand side ``[ds/dt, di/dt, dr/dt]`` for ``odeint``.

    ``x`` is the state ``(s, i, r)``. ``t`` is not used: the equations do not
    depend on time explicitly.

        ds/dt = -beta * s * i
        di/dt =  beta * s * i - gamma * i
        dr/dt =  gamma * i
    """
    susceptible, infective, _recovered = x
    infection = beta * susceptible * infective   # flow S -> I
    recovery = gamma * infective                 # flow I -> R
    return [-infection, infection - recovery, recovery]
#plot model
def plotdata(t, s, i, r, R0, e=None):
    """Draw an SIR solution as a three-panel figure and return the figure.

    Top-left: ``s``, ``i`` and ``r`` against ``t``.
    Bottom-left: ``i`` (and ``e`` when given) against ``t`` on a 0-1 scale.
    Right: the (s, i) trajectory, the vertical line ``s = 1/R0`` labelled
    ``di/dt = 0``, and markers on the first and last states.
    """
    fig = plt.figure(figsize=(12, 6))
    top_left = fig.add_subplot(221, axisbelow=True)
    bottom_left = fig.add_subplot(223)
    phase = fig.add_subplot(122)

    # Every compartment over time.
    compartments = [
        (s, 'Fraction Susceptible'),
        (i, 'Fraction Infective'),
        (r, 'Recovered'),
    ]
    for values, name in compartments:
        top_left.plot(t, values, lw=3, label=name)
    top_left.set_title('Susceptible and Recovered Populations')
    top_left.set_xlabel('Time /days')
    top_left.set_ylabel('Fraction')

    # Infective fraction alone (plus the optional exposed curve).
    bottom_left.plot(t, i, lw=3, label='Infective')
    bottom_left.set_title('Infectious Population')
    if e is not None:
        bottom_left.plot(t, e, lw=3, label='Exposed')
    bottom_left.set_ylim(0, 1.0)
    bottom_left.set_xlabel('Time /days')
    bottom_left.set_ylabel('Fraction')

    # Phase plane. With R0 = beta / gamma (as SIR computes it), di/dt from
    # deriv vanishes where s = gamma / beta = 1 / R0.
    threshold = 1 / R0
    phase.plot(s, i, lw=3, label='s, i trajectory')
    phase.plot([threshold, threshold], [0, 1], '--', lw=3, label='di/dt = 0')
    phase.plot(s[0], i[0], '.', ms=20, label='Initial Condition')
    phase.plot(s[-1], i[-1], '.', ms=20, label='Final Condition')
    phase.set_title('State Trajectory')
    phase.set_aspect('equal')
    phase.set_ylim(0, 1.05)
    phase.set_xlim(0, 1.05)
    phase.set_xlabel('Susceptible')
    phase.set_ylabel('Infectious')

    for axis in (top_left, bottom_left, phase):
        axis.grid(True)
        axis.legend()
    plt.tight_layout()
    return fig
#final model
def SIR(country, R0, t_infective):
    """Simulate an SIR epidemic for ``country`` and score it against reported cases.

    Parameters
    ----------
    country : str
        ISO code used as a key into ``all_location`` and ``pop_dict``.
    R0 : float
        Basic reproduction number (the app's slider offers 0.57 - 3.00).
    t_infective : float
        Infectious period in days; ``gamma = 1 / t_infective``.

    Returns
    -------
    tuple
        ``(R0, t_infective, beta, gamma, mape, fig)``. ``mape`` is scikit-learn's
        mean absolute percentage error (a fraction, not a percentage) between
        the reported cumulative-case fraction and the model's infective
        fraction over the same number of days. ``fig`` is the figure built by
        ``plotdata``.
    """
    gamma = 1 / t_infective
    beta = R0 * gamma

    # Reported cumulative cases as a fraction of the population, one row per day.
    observed = all_location[country]['total_cases'] / pop_dict[country]

    # Initial state.
    # NOTE(review): the initial infective fraction is *all* reported cases
    # (sum of new_cases), not the first day's total_cases, even though the
    # model is compared with the series from day 0 — confirm this is intended.
    i_initial = all_location[country]['new_cases'].sum() / pop_dict[country]
    r_initial = 0.00
    s_initial = 1 - i_initial - r_initial

    # Sample exactly once per whole day (0, 1, 2, ...) so that i[k] lines up
    # with row k of the reported data. np.linspace(0, 3000, 3000) would use a
    # step of 3000/2999 days. Always cover at least the observed period.
    n_days = max(3000, len(observed))
    t = np.arange(n_days, dtype=float)

    soln = odeint(deriv, (s_initial, i_initial, r_initial), t, args=(beta, gamma))
    s, i, r = soln.T

    rmpe = mean_absolute_percentage_error(observed, i[:len(observed)])
    return R0, t_infective, beta, gamma, rmpe, plotdata(t, s, i, r, R0)
def _model_infective(country, R0, t_infective, n_days=3000):
    """Integrate the SIR model for ``country`` and return the infective fraction per day."""
    gamma = 1 / t_infective
    beta = R0 * gamma
    # Initial condition as in SIR: all reported cases infective, nobody recovered.
    i_initial = all_location[country]['new_cases'].sum() / pop_dict[country]
    x_initial = (1 - i_initial, i_initial, 0.0)
    t = np.arange(n_days, dtype=float)  # one sample per whole day
    return odeint(deriv, x_initial, t, args=(beta, gamma))[:, 1]


def compare_plt(country, R0=None, t_infective=None):
    """Plot reported cumulative cases for ``country`` and compare them with the model.

    Left panel: raw cumulative case counts per day.
    Right panel: the same counts as a fraction of the population. When both
    ``R0`` and ``t_infective`` are given, the SIR model's infective fraction
    for those parameters is overlaid on this panel.
    Returns the new matplotlib figure.
    """
    cases = all_location[country]['total_cases']

    fig = plt.figure(figsize=(12,6))
    ax = [fig.add_subplot(121, axisbelow=True), fig.add_subplot(122)]

    ax[0].set_title('Monkeypox confirmed cases')
    ax[0].plot(cases, lw=3, label='Infective')
    ax[0].set_xlabel('Days')
    ax[0].set_ylabel('Number of cases')
    ax[0].legend()

    scaler = cases / pop_dict[country]
    ax[1].set_title('Monkeypox confirmed cases compare with model')
    ax[1].plot(scaler, lw=3, label='Real Infective')
    # Only overlay a model curve when the model parameters are actually known.
    if R0 is not None and t_infective is not None:
        ax[1].plot(_model_infective(country, R0, t_infective), lw=3, label='SIR model Infective')
    ax[1].set_ylim(0, 0.00005)
    ax[1].set_xlim(0, 200)
    ax[1].set_xlabel('Days')
    ax[1].set_ylabel('Fraction Number of cases')
    ax[1].legend()

    plt.tight_layout()
    return fig


def main():
    """Render the Streamlit page: parameter form, then the SIR figures and fit summary."""
    st.title("SIR Model for Monkeypox")
    with st.form("questionaire"):
        country = st.selectbox("Country", data['location'].unique())  # user's input
        recovery = st.slider("How long Monkeypox recover?", 21, 31, 21)
        R0 = st.slider("Basic Reproduction Number (R0)", 0.57, 3.00, 0.57)  # user's input
        # clicked is True only on the rerun triggered by the submit button
        clicked = st.form_submit_button("Show Graph")

    if not clicked:
        return

    country_code = code[country][0]

    # Show SIR
    R0_used, _t_infective, beta, gamma, mape, sir_fig = SIR(country_code, R0, recovery)
    st.pyplot(sir_fig)
    plt.close(sir_fig)  # release pyplot's reference; figures otherwise pile up across reruns

    compare_fig = compare_plt(country_code, R0, recovery)
    st.pyplot(compare_fig)
    plt.close(compare_fig)

    st.success("SIR model parameters for "+str(country)+" is")
    st.success("R0 = "+str(R0_used))
    st.success("Beta = "+str(beta))
    st.success("Gamma = "+str(gamma))
    # mean_absolute_percentage_error returns a fraction; scale it to a percentage.
    st.success("MAPE = " + f"{mape * 100:.2f}%")


# Run main()
if __name__ == "__main__":
    main()