leannebriffa commited on
Commit
47fbbc3
·
1 Parent(s): aa640d7

Updated app.py to use deep learning

Browse files
Files changed (1) hide show
  1. app.py +133 -134
app.py CHANGED
@@ -1,137 +1,136 @@
1
- import random
2
  import gradio as gr
3
- from joblib import load
4
-
5
- # Model URL for each
6
- rf_model_url = 'random_forest_model.joblib'
7
-
8
-
9
- # Load Model method
10
- def load_model(url):
11
- return load(rf_model_url)
12
-
13
-
14
- def bool_value(val):
15
- if not val:
16
- return 0
17
- else:
18
- return 1
19
-
20
-
21
- def gender(val):
22
- if val == 'Male':
23
- return 1
24
- elif val == 'Female':
25
- return 0
26
- else:
27
- return 2
28
-
29
-
30
- def race(val):
31
- if val == 'Hispanic':
32
- return 2
33
- elif val == 'Black':
34
- return 1
35
- elif val == 'White':
36
- return 5
37
- elif val == 'Other':
38
- return 4
39
- elif val == 'Asian':
40
- return 0
41
- else:
42
- return 3
43
-
44
-
45
- def search_outcome(val, end_range):
46
- if not val:
47
- return 2 # Which mean no search was conducted
48
- else:
49
- return random.randrange(0, end_range)
50
-
51
-
52
- def search_reason(val, end_range):
53
- if not val:
54
- return 726 # Which mean no search was conducted
55
- else:
56
- return random.randrange(0, end_range)
57
-
58
-
59
- # Make Prediction Model but would also like to add Gender and Race
60
- def make_predication(year_stopped, offences, search, doc, car_year, alcohol, safety, genders, speeding,
61
- races, accident, actual_accident, damage, road_signs, injury, belt, disobedience, bad_driving,
62
- phone):
63
- input_features = {
64
- 'encoded_Search Outcome': search_outcome(search, 6),
65
- 'encoded_Search Reason For Stop': search_reason(search, 727),
66
- 'Number Of Offences': offences,
67
- 'Search Conducted': bool_value(search),
68
- 'Year Stopped': year_stopped,
69
- 'encoded_SubAgency': random.randrange(0, 7),
70
- 'encoded_Arrest Type': random.randrange(0, 18),
71
- 'Invalid Documentation': bool_value(doc),
72
- 'Year': car_year, # int
73
- 'encoded_Driver City': random.randrange(0, 8114), # int
74
- 'encoded_Make': random.randrange(0, 57), # int
75
- 'Alcohol': bool_value(alcohol), # bool
76
- 'encoded_Color': random.randrange(0, 25), # int
77
- 'Vehicle Safety And Standards': bool_value(safety), # bool
78
- 'Speeding': bool_value(speeding),
79
- 'encoded_Race': race(races),
80
- 'Contributed To Accident': bool_value(accident),
81
- 'Accident': bool_value(actual_accident),
82
- 'encoded_VehicleType': random.randrange(0, 31),
83
- 'Property Damage': bool_value(damage),
84
- 'Road Signs And Markings': bool_value(road_signs),
85
- 'Personal Injury': bool_value(injury),
86
- 'encoded_Gender': gender(genders),
87
- 'Belts': bool_value(belt),
88
- 'Disobedience': bool_value(disobedience),
89
- 'encoded_DL State': random.randrange(0, 70),
90
- 'encoded_State': random.randrange(0, 68),
91
- 'Negligent Driving': bool_value(bad_driving),
92
- 'encoded_Driver State': random.randrange(0, 67),
93
- 'Mobile Phone': bool_value(phone)
94
- }
95
-
96
- x_input_feature = [[input_features[feature] for feature in sorted(input_features)]]
97
- rfc_model = load_model(rf_model_url)
98
- prd = rfc_model.predict(x_input_feature)
99
-
100
- if prd == 0:
101
- return 'Citation'
102
- elif prd == 1:
103
- return 'SERO'
104
- else:
105
- return 'Warning'
106
-
107
-
108
- iface = gr.Interface(fn=make_predication,
109
- inputs=[gr.components.Slider(minimum=2010, maximum=2023, step=1, label='Citation Year'),
110
- gr.components.Slider(minimum=1, step=1, label='Number of offences found for stop'),
111
- gr.components.Checkbox(label='Was a search conducted'),
112
- gr.components.Checkbox(label='Any invalid documents'),
113
- gr.components.Slider(minimum=1990, maximum=2023, step=1, label='Make Year'),
114
- gr.components.Checkbox(label='Was alcohol involved'),
115
- gr.components.Checkbox(label='Safety Standard issues in Vehicle'),
116
- gr.components.Dropdown(label='Gender', choices=['Female',
117
- 'Male',
118
- 'Unknown']),
119
- gr.components.Checkbox(label='Speeding'),
120
- gr.components.Dropdown(label='Race', choices=['Asian',
121
- 'Black',
122
- 'Hispanic',
123
- 'Native American',
124
- 'Other',
125
- 'White']),
126
- gr.components.Checkbox(label='Irregularities which could contribute to accident'),
127
- gr.components.Checkbox(label='Was the stop actual for an accident'),
128
- gr.components.Checkbox(label='Any Property damage'),
129
- gr.components.Checkbox(label='Road sign violation involvement'),
130
- gr.components.Checkbox(label='Any injuries'),
131
- gr.components.Checkbox(label='Seat Belts irregulation'),
132
- gr.components.Checkbox(label='Any disobedience'),
133
- gr.components.Checkbox(label='Bad driving'),
134
- gr.components.Checkbox(label='Mobile phone')],
135
  outputs=["text"])
136
 
137
- iface.launch(debug=True)
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ from tensorflow import keras
5
+ import pickle
6
+
7
+ # Configuration section START
8
+
9
+ with open('model_config.pkl', 'rb') as f:
10
+ tf = pickle.load(f)
11
+
12
+ model = keras.models.load_model('model.keras')
13
+
14
+ all_selected_features = ['Alcohol', 'Arrest Type', 'Belts', 'Contributed To Accident', 'Disobedience', 'Driver City',
15
+ 'Gender',
16
+ 'Invalid Documentation', 'Make', 'Mobile Phone', 'Negligent Driving', 'Number Of Offences',
17
+ 'Personal Injury',
18
+ 'Property Damage', 'Race', 'Road Signs And Markings', 'Search Outcome', 'Speeding',
19
+ 'Stop Hour', 'Stop Year',
20
+ 'SubAgency', 'Vehicle Safety And Standards', 'VehicleType', 'Year']
21
+ very_high_cardinality_features = ['Driver City']
22
+ high_cardinality_features = ['Make', 'VehicleType']
23
+ bools = ['Alcohol', 'Belts', 'Contributed To Accident', 'Disobedience', 'Invalid Documentation', 'Mobile Phone',
24
+ 'Negligent Driving',
25
+ 'Personal Injury', 'Property Damage', 'Road Signs And Markings', 'Speeding', 'Vehicle Safety And Standards']
26
+ num_cols = ['Driver City', 'Make', 'Number Of Offences', 'Stop Hour', 'Stop Year', 'VehicleType', 'Year']
27
+
28
+
29
+ # Configuration section END
30
+
31
+
32
+ make_prediction(alcohol, arrest_type, belts, contributed_to_accident, disobedience, driver_city, gender,
33
+ invalid_documentation, make, mobile_phone, negligent_driving, number_of_offences, personal_injury,
34
+ property_damage, race, road_signs_and_markings, search_outcome, speeding, stop_hour, stop_year,
35
+ subagency, vehicle_safety_and_standards, vehicletype, year):
36
+ """
37
+ Function to predict the 'Violation Type' of an individual sample of traffic stop:
38
+ :param alcohol: boolean
39
+ :param arrest_type: String
40
+ :param belts: Boolean
41
+ :param contributed_to_accident: Boolean
42
+ :param disobedience: Boolean
43
+ :param driver_city: String
44
+ :param gender: 'M', 'F' or 'N'
45
+ :param invalid_documentation: Boolean
46
+ :param make: String
47
+ :param mobile_phone: Boolean
48
+ :param negligent_driving: Boolean
49
+ :param number_of_offences: Integer
50
+ :param personal_injury: Boolean
51
+ :param property_damage: Boolean
52
+ :param race: One of 'HISPANIC', 'BLACK', 'WHITE', 'OTHER', 'ASIAN', or 'NATIVE AMERICAN'
53
+ :param road_signs_and_markings: Boolean
54
+ :param search_outcome: String
55
+ :param speeding: Boolean
56
+ :param stop_hour: Integer
57
+ :param stop_year: Integer
58
+ :param subagency: String
59
+ :param vehicle_safety_and_standards: Boolean
60
+ :param vehicletype: String
61
+ :param year: Integer
62
+ :return:
63
+ """
64
+ # Create a dataframe with the feature values
65
+ X = pd.DataFrame([[alcohol, arrest_type, belts, contributed_to_accident, disobedience, driver_city, gender,
66
+ invalid_documentation, make,
67
+ mobile_phone, negligent_driving, number_of_offences, personal_injury, property_damage, race,
68
+ road_signs_and_markings,
69
+ search_outcome, speeding, stop_hour, stop_year, subagency, vehicle_safety_and_standards,
70
+ vehicletype, year]],
71
+ columns=all_selected_features)
72
+
73
+ # Transform the features
74
+ # Encode very high cardinality features (Driver City) with ordinal encoding, by fitting on the whole dataset
75
+ oev = tf["OrdinalEncoder_VeryHighCardinality"].set_params(handle_unknown='use_encoded_value').set_params(
76
+ unknown_value=6714)
77
+ X[very_high_cardinality_features] = oev.transform(X[very_high_cardinality_features])
78
+
79
+ # Encode high cardinality features with ordinal encoding by fitting only on the training set
80
+ X[high_cardinality_features] = tf["OrdinalEncoder_HighCardinality"].transform(X[high_cardinality_features])
81
+
82
+ # Scale all the numerical features, including those that were ordinal encoded
83
+ X[num_cols] = tf["StandardScaler"].transform(X[num_cols])
84
+
85
+ # Convert booleans to numbers
86
+ X[bools] = X[bools].astype('int8')
87
+
88
+ # One-hot encode the low cardinality features
89
+ X = tf['OneHotEncoder'].transform(X)
90
+
91
+ # Make the prediction using the model
92
+ prediction = model.predict(X, verbose=0)
93
+
94
+ prediction = tf["LabelEncoder"].inverse_transform(np.argmax(prediction, axis=1))
95
+
96
+ # Return the prediction
97
+ return prediction[0]
98
+
99
+ arrest_types = ['A - Marked Patrol', 'G - Marked Moving Radar (Stationary)',
100
+ 'Q - Marked Laser', 'L - Motorcycle',
101
+ 'H - Unmarked Moving Radar (Stationary)', 'O - Foot Patrol',
102
+ 'E - Marked Stationary Radar', 'B - Unmarked Patrol',
103
+ 'S - License Plate Recognition', 'R - Unmarked Laser',
104
+ 'J - Unmarked Moving Radar (Moving)', 'M - Marked (Off-Duty)',
105
+ 'I - Marked Moving Radar (Moving)',
106
+ 'F - Unmarked Stationary Radar', 'D - Unmarked VASCAR',
107
+ 'K - Aircraft Assist', 'C - Marked VASCAR', 'P - Mounted Patrol',
108
+ 'N - Unmarked (Off-Duty)']
109
+
110
+ iface = gr.Interface(fn=make_prediction,
111
+ inputs=[gr.components.Checkbox(label='Was the driver under the influence of alcohol?'),
112
+ gr.components.Dropdown(label='Choose the arrest type', choices=arrest_types, value='A - Marked Patrol')
113
+ gr.components.Checkbox(label='Were seatbelts used appropriately?'),
114
+ gr.components.Checkbox(label='Did the driver actions contribute to an accident?'),
115
+ gr.components.Checkbox(label='Was the driver disobedient? (such as failing to display documentation upon request)?'),
116
+ gr.components.Dropdown(label='Choose the arrest type', choices=tf['OrdinalEncoder_VeryHighCardinality'].categories_, value='SILVER SPRING')
117
+ gr.components.Dropdown(label='Driver Gender', choices=['M', 'F', 'N'], value='M')
118
+ gr.components.Checkbox(label='Was the driver driving with Invalid Documentation (such as suspended registration, suspended license, expired registration plates and validation tabs or expired license plate)?'),
119
+ gr.components.Dropdown(label='Vehicle Make', choices=tf['OrdinalEncoder_HighCardinality'].categories_[0], value='TOYOTA')
120
+ gr.components.Checkbox(label='Was the driver using a mobile phone while driving?'),
121
+ gr.components.Checkbox(label='Was the driver caught driving with negligence (example switching lanes in an unsafe manner)?'),
122
+ gr.components.Slider(minimum=1, step=1, label='Number of offences committed')],
123
+ gr.components.Checkbox(label='Did the violation involve any personal injury?'),
124
+ gr.components.Checkbox(label='Did the violation involve any property damage?'),
125
+ gr.components.Dropdown(label='Choose the race of the driver', choices=tf['OneHotEncoder'].transformers_[0][1].categories_[2], value='WHITE')
126
+ gr.components.Checkbox(label='Did the driver fail to obey signs and markings (such as traffic control device instructions, stop lights, red signal and stop sign lines)?'),
127
+ gr.components.Dropdown(label='Choose the race of the driver', choices=tf['OneHotEncoder'].transformers_[0][1].categories_[3], value='NO SEARCH CONDUCTED')
128
+ gr.components.Checkbox(label='Was the driver caught speeding?'),
129
+ gr.components.Slider(maximum=23, step=1, label='Time HOUR when stop occurred in 24-hour format')],
130
+ gr.components.Slider(minimum=2012, maximum=2024, step=1, label='Year when stop occurred')],
131
+ gr.components.Dropdown(label='What is the name of the subagency that conducted the traffic stop?', choices=tf['OneHotEncoder'].transformers_[0][1].categories_[1], value='4th District, Wheaton')
132
+ gr.components.Checkbox(label='Was the vehicle safe and up to standards (lights properly switched, registration plates attached etc.)?'),
133
+ gr.components.Slider(minimum=1970, maximum=2023, step=1, label='Year of manufacture of the vehicle:)],
134
  outputs=["text"])
135
 
136
+ iface.launch(debug=True)