leannebriffa commited on
Commit
729ccd9
·
1 Parent(s): 131514f

Fixed syntax errors in app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -116
app.py CHANGED
@@ -1,136 +1,51 @@
1
  import gradio as gr
2
- import pandas as pd
3
- import numpy as np
4
- from tensorflow import keras
5
  import pickle
6
 
7
- # Configuration section START
 
8
 
9
- with open('model_config.pkl', 'rb') as f:
10
- tf = pickle.load(f)
 
 
 
11
 
12
- model = keras.models.load_model('model.keras')
13
 
14
- all_selected_features = ['Alcohol', 'Arrest Type', 'Belts', 'Contributed To Accident', 'Disobedience', 'Driver City',
15
- 'Gender',
16
- 'Invalid Documentation', 'Make', 'Mobile Phone', 'Negligent Driving', 'Number Of Offences',
17
- 'Personal Injury',
18
- 'Property Damage', 'Race', 'Road Signs And Markings', 'Search Outcome', 'Speeding',
19
- 'Stop Hour', 'Stop Year',
20
- 'SubAgency', 'Vehicle Safety And Standards', 'VehicleType', 'Year']
21
- very_high_cardinality_features = ['Driver City']
22
- high_cardinality_features = ['Make', 'VehicleType']
23
- bools = ['Alcohol', 'Belts', 'Contributed To Accident', 'Disobedience', 'Invalid Documentation', 'Mobile Phone',
24
- 'Negligent Driving',
25
- 'Personal Injury', 'Property Damage', 'Road Signs And Markings', 'Speeding', 'Vehicle Safety And Standards']
26
- num_cols = ['Driver City', 'Make', 'Number Of Offences', 'Stop Hour', 'Stop Year', 'VehicleType', 'Year']
27
 
 
 
 
28
 
29
- # Configuration section END
 
 
 
30
 
 
 
 
 
 
 
31
 
32
- make_prediction(alcohol, arrest_type, belts, contributed_to_accident, disobedience, driver_city, gender,
33
- invalid_documentation, make, mobile_phone, negligent_driving, number_of_offences, personal_injury,
34
- property_damage, race, road_signs_and_markings, search_outcome, speeding, stop_hour, stop_year,
35
- subagency, vehicle_safety_and_standards, vehicletype, year):
36
- """
37
- Function to predict the 'Violation Type' of an individual sample of traffic stop:
38
- :param alcohol: boolean
39
- :param arrest_type: String
40
- :param belts: Boolean
41
- :param contributed_to_accident: Boolean
42
- :param disobedience: Boolean
43
- :param driver_city: String
44
- :param gender: 'M', 'F' or 'N'
45
- :param invalid_documentation: Boolean
46
- :param make: String
47
- :param mobile_phone: Boolean
48
- :param negligent_driving: Boolean
49
- :param number_of_offences: Integer
50
- :param personal_injury: Boolean
51
- :param property_damage: Boolean
52
- :param race: One of 'HISPANIC', 'BLACK', 'WHITE', 'OTHER', 'ASIAN', or 'NATIVE AMERICAN'
53
- :param road_signs_and_markings: Boolean
54
- :param search_outcome: String
55
- :param speeding: Boolean
56
- :param stop_hour: Integer
57
- :param stop_year: Integer
58
- :param subagency: String
59
- :param vehicle_safety_and_standards: Boolean
60
- :param vehicletype: String
61
- :param year: Integer
62
- :return:
63
- """
64
- # Create a dataframe with the feature values
65
- X = pd.DataFrame([[alcohol, arrest_type, belts, contributed_to_accident, disobedience, driver_city, gender,
66
- invalid_documentation, make,
67
- mobile_phone, negligent_driving, number_of_offences, personal_injury, property_damage, race,
68
- road_signs_and_markings,
69
- search_outcome, speeding, stop_hour, stop_year, subagency, vehicle_safety_and_standards,
70
- vehicletype, year]],
71
- columns=all_selected_features)
72
-
73
- # Transform the features
74
- # Encode very high cardinality features (Driver City) with ordinal encoding, by fitting on the whole dataset
75
- oev = tf["OrdinalEncoder_VeryHighCardinality"].set_params(handle_unknown='use_encoded_value').set_params(
76
- unknown_value=6714)
77
- X[very_high_cardinality_features] = oev.transform(X[very_high_cardinality_features])
78
-
79
- # Encode high cardinality features with ordinal encoding by fitting only on the training set
80
- X[high_cardinality_features] = tf["OrdinalEncoder_HighCardinality"].transform(X[high_cardinality_features])
81
-
82
- # Scale all the numerical features, including those that were ordinal encoded
83
- X[num_cols] = tf["StandardScaler"].transform(X[num_cols])
84
-
85
- # Convert booleans to numbers
86
- X[bools] = X[bools].astype('int8')
87
-
88
- # One-hot encode the low cardinality features
89
- X = tf['OneHotEncoder'].transform(X)
90
-
91
- # Make the prediction using the model
92
- prediction = model.predict(X, verbose=0)
93
-
94
- prediction = tf["LabelEncoder"].inverse_transform(np.argmax(prediction, axis=1))
95
-
96
- # Return the prediction
97
- return prediction[0]
98
-
99
- arrest_types = ['A - Marked Patrol', 'G - Marked Moving Radar (Stationary)',
100
- 'Q - Marked Laser', 'L - Motorcycle',
101
- 'H - Unmarked Moving Radar (Stationary)', 'O - Foot Patrol',
102
- 'E - Marked Stationary Radar', 'B - Unmarked Patrol',
103
- 'S - License Plate Recognition', 'R - Unmarked Laser',
104
- 'J - Unmarked Moving Radar (Moving)', 'M - Marked (Off-Duty)',
105
- 'I - Marked Moving Radar (Moving)',
106
- 'F - Unmarked Stationary Radar', 'D - Unmarked VASCAR',
107
- 'K - Aircraft Assist', 'C - Marked VASCAR', 'P - Mounted Patrol',
108
- 'N - Unmarked (Off-Duty)']
109
 
110
  iface = gr.Interface(fn=make_prediction,
111
- inputs=[gr.components.Checkbox(label='Was the driver under the influence of alcohol?'),
112
- gr.components.Dropdown(label='Choose the arrest type', choices=arrest_types, value='A - Marked Patrol')
113
- gr.components.Checkbox(label='Were seatbelts used appropriately?'),
114
- gr.components.Checkbox(label='Did the driver actions contribute to an accident?'),
 
 
115
  gr.components.Checkbox(label='Was the driver disobedient? (such as failing to display documentation upon request)?'),
116
- gr.components.Dropdown(label='Choose the arrest type', choices=tf['OrdinalEncoder_VeryHighCardinality'].categories_, value='SILVER SPRING')
117
- gr.components.Dropdown(label='Driver Gender', choices=['M', 'F', 'N'], value='M')
118
  gr.components.Checkbox(label='Was the driver driving with Invalid Documentation (such as suspended registration, suspended license, expired registration plates and validation tabs or expired license plate)?'),
119
- gr.components.Dropdown(label='Vehicle Make', choices=tf['OrdinalEncoder_HighCardinality'].categories_[0], value='TOYOTA')
120
  gr.components.Checkbox(label='Was the driver using a mobile phone while driving?'),
 
121
  gr.components.Checkbox(label='Was the driver caught driving with negligence (example switching lanes in an unsafe manner)?'),
 
122
  gr.components.Slider(minimum=1, step=1, label='Number of offences committed')],
123
- gr.components.Checkbox(label='Did the violation involve any personal injury?'),
124
- gr.components.Checkbox(label='Did the violation involve any property damage?'),
125
- gr.components.Dropdown(label='Choose the race of the driver', choices=tf['OneHotEncoder'].transformers_[0][1].categories_[2], value='WHITE')
126
- gr.components.Checkbox(label='Did the driver fail to obey signs and markings (such as traffic control device instructions, stop lights, red signal and stop sign lines)?'),
127
- gr.components.Dropdown(label='Choose the race of the driver', choices=tf['OneHotEncoder'].transformers_[0][1].categories_[3], value='NO SEARCH CONDUCTED')
128
- gr.components.Checkbox(label='Was the driver caught speeding?'),
129
- gr.components.Slider(maximum=23, step=1, label='Time HOUR when stop occurred in 24-hour format')],
130
- gr.components.Slider(minimum=2012, maximum=2024, step=1, label='Year when stop occurred')],
131
- gr.components.Dropdown(label='What is the name of the subagency that conducted the traffic stop?', choices=tf['OneHotEncoder'].transformers_[0][1].categories_[1], value='4th District, Wheaton')
132
- gr.components.Checkbox(label='Was the vehicle safe and up to standards (lights properly switched, registration plates attached etc.)?'),
133
- gr.components.Slider(minimum=1970, maximum=2023, step=1, label='Year of manufacture of the vehicle:)],
134
  outputs=["text"])
135
 
136
  iface.launch(debug=True)
 
1
  import gradio as gr
 
 
 
2
  import pickle
3
 
4
+ # Model URL for each
5
+ lr_model_url = './logistic regression/logistic_regression_model.pkl'
6
 
7
+ def bool_value(val):
8
+ if val:
9
+ return 1
10
+ else:
11
+ return 0
12
 
 
13
 
14
+ # Make Prediction Model but would also like to add Gender and Race
15
+ def make_prediction(personal_injury, property_damage, fatal, commercial_vehicle, alcohol, rsam, disobedience, invalid_docu, phone, speeding,
16
+ negligent, vss, num_offences):
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # load model
19
+ with open(lr_model_url, 'rb') as file:
20
+ lr_model = pickle.load(file)
21
 
22
+ x_input_feature = [[bool_value(personal_injury), bool_value(property_damage), bool_value(fatal), bool_value(commercial_vehicle), bool_value(alcohol),
23
+ bool_value(rsam), bool_value(disobedience), bool_value(invalid_docu), bool_value(phone), bool_value(speeding), bool_value(negligent),
24
+ bool_value(vss), num_offences]]
25
+ prd = lr_model.predict(x_input_feature)
26
 
27
+ if prd == 0:
28
+ return 'SERO'
29
+ elif prd == 1:
30
+ return 'Warning'
31
+ else:
32
+ return 'Citation'
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  iface = gr.Interface(fn=make_prediction,
36
+ inputs=[gr.components.Checkbox(label='Did the violation involve any personal injury?'),
37
+ gr.components.Checkbox(label='Did the violation involve any property damage?'),
38
+ gr.components.Checkbox(label='Did the violation involve any fatalities?'),
39
+ gr.components.Checkbox(label='Is the vehicle committing the traffic violation a commercial vehicle?'),
40
+ gr.components.Checkbox(label='Was the driver under the influence of alcohol?'),
41
+ gr.components.Checkbox(label='Did the driver fail to obey signs and markings (such as traffic control device instructions, stop lights, red signal and stop sign lines)?'),
42
  gr.components.Checkbox(label='Was the driver disobedient? (such as failing to display documentation upon request)?'),
 
 
43
  gr.components.Checkbox(label='Was the driver driving with Invalid Documentation (such as suspended registration, suspended license, expired registration plates and validation tabs or expired license plate)?'),
 
44
  gr.components.Checkbox(label='Was the driver using a mobile phone while driving?'),
45
+ gr.components.Checkbox(label='Was the driver caught speeding?'),
46
  gr.components.Checkbox(label='Was the driver caught driving with negligence (example switching lanes in an unsafe manner)?'),
47
+ gr.components.Checkbox(label='Was the vehicle up to standards (lights properly switched, registration plates attached etc.)?'),
48
  gr.components.Slider(minimum=1, step=1, label='Number of offences committed')],
 
 
 
 
 
 
 
 
 
 
 
49
  outputs=["text"])
50
 
51
  iface.launch(debug=True)