File size: 3,360 Bytes
3829ec7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ad9eb4
3829ec7
9f417d1
3829ec7
 
9178d8d
 
be959b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3829ec7
 
9178d8d
be959b3
 
3829ec7
 
f361f28
9178d8d
3829ec7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a8dace
be959b3
3829ec7
be959b3
3829ec7
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import plotly.graph_objects as go
import json
import requests
import os
from PIL import Image
import hopsworks
import joblib
import pandas as pd
import numpy as np

# TomTom API key read from the environment.
# NOTE(review): API_KEY is not referenced anywhere in this file — confirm it is still needed.
API_KEY = os.getenv("API-KEY-TOMTOM")

# Log into Hopsworks and fetch "sthlm_incidents_model" (version 4) from its model registry.
project = hopsworks.login()
# NOTE(review): the feature-store handle is not used in this file.
fs = project.get_feature_store()
mr = project.get_model_registry()

# Keep the registry entry and the deserialised estimator under separate names.
# `model` must end up bound to the estimator, because `predict` calls `model.predict`.
registry_model = mr.get_model("sthlm_incidents_model", version=4)
model_dir = registry_model.download()
# joblib.load unpickles the file, so only load artifacts from a trusted registry.
model = joblib.load(os.path.join(model_dir, "sthlm_model.pkl"))
print("Model downloaded")

# Integer encodings for the categorical Radio choices. The values are the ones
# the original if/elif chains assigned; icon-category codes 12 and 13 have no
# label. These presumably mirror TomTom traffic-incident codes — confirm.
MAGNITUDE_OF_DELAY_CODES = {
    'Unknown': 0,
    'Minor': 1,
    'Moderate': 2,
    'Major': 3,
    'Undefined': 4,
}

ICON_CATEGORY_CODES = {
    'Unknown': 0,
    'Accident': 1,
    'Fog': 2,
    'Dangerous Conditions': 3,
    'Rain': 4,
    'Ice': 5,
    'Jam': 6,
    'Lane Closed': 7,
    'Road Closed': 8,
    'Road Works': 9,
    'Wind': 10,
    'Flooding': 11,
    'Broken Down Vehicle': 14,
}


def _encode_category(value, codes, field):
    """Return the integer code for *value* using the lookup table *codes*.

    A value that is not a label in *codes* is passed through ``int()``, so
    callers that already supply numeric codes keep working. Anything that
    cannot be converted (e.g. ``None`` or an unknown label) raises
    ``ValueError`` naming *field*.
    """
    try:
        return int(codes.get(value, value))
    except (TypeError, ValueError) as err:
        raise ValueError(f"Unsupported {field} value: {value!r}") from err


def predict(code, magnitudeOfDelay, hour, iconCategory, latitude, longitude, month):
    """Predict the duration of a Stockholm traffic incident.

    The Gradio interface calls this with its widget values, in the order of
    its ``inputs`` list.

    Args:
        code: Incident code; anything ``int()`` accepts (the UI sends strings
            such as ``"101"``).
        magnitudeOfDelay: A label from ``MAGNITUDE_OF_DELAY_CODES`` or an
            integer code.
        hour: Hour of the day; truncated with ``int()``.
        iconCategory: A label from ``ICON_CATEGORY_CODES`` or an integer code.
        latitude: Passed to the model unchanged.
        longitude: Passed to the model unchanged.
        month: Month number; truncated with ``int()``.

    Returns:
        The first element of ``model.predict`` for the single-row input.

    Raises:
        ValueError: If a categorical input is missing or unrecognised.
    """
    # Validate the categorical inputs first, so bad UI input fails with a clear message.
    magnitude_code = _encode_category(magnitudeOfDelay, MAGNITUDE_OF_DELAY_CODES, 'magnitudeOfDelay')
    icon_code = _encode_category(iconCategory, ICON_CATEGORY_CODES, 'iconCategory')

    # Build a one-row frame. The column names and order are exactly what the
    # previous version produced by lower-casing camelCase keys after
    # construction; the model is assumed to have been trained on these names.
    df_row = pd.DataFrame(
        {
            'code': int(code),
            'hour': int(hour),
            'iconcategory': icon_code,
            'latitude': latitude,
            'longitude': longitude,
            'magnitudeofdelay': magnitude_code,
            'month': int(month),
        },
        index=[0],
    )

    # `model` is the module-level estimator loaded via joblib.
    return model.predict(df_row)[0]

# Gradio UI. Input widget values are passed to `predict` positionally, so their
# order must match its parameters: code, magnitudeOfDelay, hour, iconCategory,
# latitude, longitude, month.
# Top-level component classes (gr.Radio, gr.Slider, gr.Textbox) are used instead
# of the deprecated gr.inputs / gr.outputs namespaces, which Gradio 4 removed.
demo = gr.Interface(
    fn=predict,
    title="Stockholm Incident Prediction",
    description="Predicts the duration of a traffic incident in Stockholm",
    allow_flagging="never",
    inputs=[
        gr.Radio(["101", "108", "115", "122", "201", "500", "701", "1101"], label="Code"),
        gr.Radio(["Unknown", "Minor", "Moderate", "Major", "Undefined"], label="Magnitude of Delay"),
        # step=1: predict() truncates hour and month with int(). Without an
        # explicit step, Gradio picks a fractional one for these ranges.
        gr.Slider(0, 23, step=1, label="Hour"),
        gr.Radio(["Unknown", "Accident", "Fog", "Dangerous Conditions", "Rain", "Ice", "Jam", "Lane Closed", "Road Closed", "Road Works", "Wind", "Flooding", "Broken Down Vehicle"], label="Icon Category"),
        gr.Slider(59.25, 59.40, label="Latitude"),
        gr.Slider(18.00, 18.16, label="Longitude"),
        gr.Slider(1, 12, step=1, label="Month"),
    ],
    outputs=[
        gr.Textbox(label="Duration"),
    ],
)

demo.launch()