namdini commited on
Commit
563b9ed
·
verified ·
1 Parent(s): 15ca59b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -8
app.py CHANGED
@@ -6,6 +6,17 @@ import folium
6
  from folium.plugins import HeatMap, MarkerCluster
7
  from streamlit_folium import st_folium
8
 
 
 
 
 
 
 
 
 
 
 
 
9
  @st.cache_data
10
  def load_and_preprocess_data(file_path):
11
  # Read the data
@@ -69,7 +80,8 @@ def create_severity_violation_chart(df, age_group=None):
69
  color='Severity',
70
  title=f'Crash Severity Distribution by Violation Type - {age_group}',
71
  labels={'count': 'Number of Incidents', 'Violation': 'Violation Type'},
72
- height=600
 
73
  )
74
 
75
  fig.update_layout(
@@ -78,7 +90,8 @@ def create_severity_violation_chart(df, age_group=None):
78
  barmode='stack'
79
  )
80
 
81
- return fig
 
82
 
83
  def get_top_violations(df, age_group):
84
  if age_group == 'All Ages':
@@ -104,6 +117,23 @@ def get_top_violations(df, age_group):
104
 
105
  return violations_df.head()
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  @st.cache_data
108
  def create_map(df, selected_year):
109
  filtered_df = df[df['Year'] == selected_year]
@@ -171,10 +201,10 @@ def create_injuries_fatalities_chart(crash_data, unit_type):
171
 
172
  # Reshape the data for easier plotting
173
  injuries = monthly_sum[['Month', 'Totalinjuries']].rename(columns={'Totalinjuries': 'Value'})
174
- injuries['Measure'] = 'Total Injuries'
175
 
176
  fatalities = monthly_sum[['Month', 'Totalfatalities']].rename(columns={'Totalfatalities': 'Value'})
177
- fatalities['Measure'] = 'Total Fatalities'
178
 
179
  combined_data = pd.concat([injuries, fatalities])
180
 
@@ -221,8 +251,8 @@ def create_injuries_fatalities_chart(crash_data, unit_type):
221
  line_chart = alt.Chart(combined_data).mark_line(point=True).encode(
222
  x=alt.X('Month:N', sort=month_order, title='Month'),
223
  y=alt.Y('Value:Q', title='Total Injuries & Fatalities'),
224
- color=alt.Color('Measure:N', title='', scale=alt.Scale(domain=['Total Injuries', 'Total Fatalities'], range=['blue', 'red'])),
225
- tooltip=['Month', 'Measure:N', 'Value:Q']
226
  ).properties(
227
  title=f'Total Injuries and Fatalities by Month for Unit Type Pair: {unit_type}',
228
  width=600,
@@ -356,6 +386,10 @@ def main():
356
 
357
  if 'Weather' not in df.columns:
358
  df['Weather'] = 'Unknown'
 
 
 
 
359
 
360
  # Create tabs for different visualizations
361
  tab1, tab2, tab3, tab4, tab5 = st.tabs(["Crash Statistics", "Crash Map", "Crash Trend", "Crash Injuries/Fatalities","Distribution by Category"])
@@ -366,8 +400,23 @@ def main():
366
  selected_age = st.selectbox('Select Age Group:', age_groups)
367
 
368
  # Create and display chart
369
- fig = create_severity_violation_chart(df, selected_age)
370
- st.plotly_chart(fig, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
 
372
  # Display statistics
373
  if selected_age == 'All Ages':
 
6
  from folium.plugins import HeatMap, MarkerCluster
7
  from streamlit_folium import st_folium
8
 
9
+ # To fix the color scheme in crash stats plot (asked ChatGPT for appropriate colors)
10
+ severity_colors = {
11
+ "No Injury": "#1f77b4",
12
+ "Possible Injury": "#aec7e8",
13
+ "Non Incapacitating Injury": "#ff7f0e",
14
+ "Incapacitating Injury": "#ffbb78",
15
+ "Suspected Minor Injury": "#2ca02c",
16
+ "Suspected Serious Injury": "#98df8a",
17
+ "Fatal": "#d62728",
18
+ }
19
+
20
  @st.cache_data
21
  def load_and_preprocess_data(file_path):
22
  # Read the data
 
80
  color='Severity',
81
  title=f'Crash Severity Distribution by Violation Type - {age_group}',
82
  labels={'count': 'Number of Incidents', 'Violation': 'Violation Type'},
83
+ height=600,
84
+ color_discrete_map=severity_colors, # --> for part 3
85
  )
86
 
87
  fig.update_layout(
 
90
  barmode='stack'
91
  )
92
 
93
+ # return fig
94
+ return fig, violations
95
 
96
  def get_top_violations(df, age_group):
97
  if age_group == 'All Ages':
 
117
 
118
  return violations_df.head()
119
 
120
+ # added interactivity pie plot for part3 (linked with crash stats visualization)
121
+ def create_interactive_pie_chart(violations, selected_violation):
122
+ # Filter data based on selected violation
123
+ filtered_data = violations[violations['Violation'] == selected_violation]
124
+
125
+ # Create a pie chart for severity distribution of the selected violation type
126
+ fig = px.pie(
127
+ filtered_data,
128
+ names='Severity',
129
+ values='count',
130
+ title=f'Severity Level Distribution for Violation: {selected_violation}',
131
+ height=400,
132
+ color_discrete_map=severity_colors
133
+ )
134
+
135
+ return fig
136
+
137
  @st.cache_data
138
  def create_map(df, selected_year):
139
  filtered_df = df[df['Year'] == selected_year]
 
201
 
202
  # Reshape the data for easier plotting
203
  injuries = monthly_sum[['Month', 'Totalinjuries']].rename(columns={'Totalinjuries': 'Value'})
204
+ injuries['Type'] = 'Total Injuries'
205
 
206
  fatalities = monthly_sum[['Month', 'Totalfatalities']].rename(columns={'Totalfatalities': 'Value'})
207
+ fatalities['Type'] = 'Total Fatalities'
208
 
209
  combined_data = pd.concat([injuries, fatalities])
210
 
 
251
  line_chart = alt.Chart(combined_data).mark_line(point=True).encode(
252
  x=alt.X('Month:N', sort=month_order, title='Month'),
253
  y=alt.Y('Value:Q', title='Total Injuries & Fatalities'),
254
+ color=alt.Color('Type:N', title='', scale=alt.Scale(domain=['Total Injuries', 'Total Fatalities'], range=['blue', 'red'])),
255
+ tooltip=['Month', 'Type:N', 'Value:Q']
256
  ).properties(
257
  title=f'Total Injuries and Fatalities by Month for Unit Type Pair: {unit_type}',
258
  width=600,
 
386
 
387
  if 'Weather' not in df.columns:
388
  df['Weather'] = 'Unknown'
389
+
390
+ # Initialize session state to store selected violation --> added for part3 (interactive pie chart)
391
+ if 'selected_violation' not in st.session_state:
392
+ st.session_state['selected_violation'] = None
393
 
394
  # Create tabs for different visualizations
395
  tab1, tab2, tab3, tab4, tab5 = st.tabs(["Crash Statistics", "Crash Map", "Crash Trend", "Crash Injuries/Fatalities","Distribution by Category"])
 
400
  selected_age = st.selectbox('Select Age Group:', age_groups)
401
 
402
  # Create and display chart
403
+ # fig = create_severity_violation_chart(df, selected_age)
404
+ fig, violations = create_severity_violation_chart(df, selected_age) # --> for part3 (interactive pie chart)
405
+ # st.plotly_chart(fig, use_container_width=True)
406
+
407
+ # Display the first bar chart and capture click events using plotly_events
408
+ clicked_points = plotly_events(fig, click_event=True, override_height=600, override_width="100%") # added for part3 (interactive pie chart)
409
+
410
+ # If a bar is clicked, update the selected_violation in session_state --> added for part3 (interactive pie chart)
411
+ if clicked_points:
412
+ selected_violation = clicked_points[0]['x']
413
+ if selected_violation != st.session_state['selected_violation']:
414
+ st.session_state['selected_violation'] = selected_violation
415
+
416
+ # If a violation is selected, display the pie chart --> added for part3 (interactive pie chart)
417
+ if st.session_state['selected_violation']:
418
+ pie_chart = create_interactive_pie_chart(violations, st.session_state['selected_violation'])
419
+ st.plotly_chart(pie_chart, use_container_width=True)
420
 
421
  # Display statistics
422
  if selected_age == 'All Ages':