# Captured from the Hugging Face Space page; the Space status at capture time was "Runtime error".
import gradio as gr | |
import plotly.graph_objects as go | |
import json | |
import requests | |
import os | |
from PIL import Image | |
import hopsworks | |
import joblib | |
import pandas as pd | |
import numpy as np | |
# Presumably the TomTom API key, read from the Space's secrets.
# NOTE(review): it is not referenced anywhere else in this file; confirm whether it is still needed.
API_KEY = os.getenv("API-KEY-TOMTOM")

# Log into Hopsworks and obtain handles to the feature store and model registry.
project = hopsworks.login()
fs = project.get_feature_store()  # NOTE(review): not used elsewhere in this file
mr = project.get_model_registry()

# Download the pinned model version and load the serialized estimator.
# The registry entry and the loaded estimator get distinct names, so the global
# `model` (used by `predict`) is only ever the estimator.
# joblib.load unpickles the file, so it must only be pointed at trusted
# artifacts such as this project's own registry.
model_entry = mr.get_model("sthlm_incidents_model", version=4)
model_dir = model_entry.download()
model = joblib.load(os.path.join(model_dir, "sthlm_model.pkl"))
print("Model downloaded")
def predict(code, magnitudeOfDelay, hour, iconCategory, latitude, longitude, month): | |
# Change the magnitude delay to an integer based on the index | |
if magnitudeOfDelay == 'Unknown': | |
magnitudeOfDelay = 0 | |
elif magnitudeOfDelay == 'Minor': | |
magnitudeOfDelay = 1 | |
elif magnitudeOfDelay == 'Moderate': | |
magnitudeOfDelay = 2 | |
elif magnitudeOfDelay == 'Major': | |
magnitudeOfDelay = 3 | |
elif magnitudeOfDelay == 'Undefined': | |
magnitudeOfDelay = 4 | |
# Change the icon category to an integer based on the index | |
if iconCategory == 'Unknown': | |
iconCategory = 0 | |
elif iconCategory == 'Accident': | |
iconCategory = 1 | |
elif iconCategory == 'Fog': | |
iconCategory = 2 | |
elif iconCategory == 'Dangerous Conditions': | |
iconCategory = 3 | |
elif iconCategory == 'Rain': | |
iconCategory = 4 | |
elif iconCategory == 'Ice': | |
iconCategory = 5 | |
elif iconCategory == 'Jam': | |
iconCategory = 6 | |
elif iconCategory == 'Lane Closed': | |
iconCategory = 7 | |
elif iconCategory == 'Road Closed': | |
iconCategory = 8 | |
elif iconCategory == 'Road Works': | |
iconCategory = 9 | |
elif iconCategory == 'Wind': | |
iconCategory = 10 | |
elif iconCategory == 'Flooding': | |
iconCategory = 11 | |
elif iconCategory == 'Broken Down Vehicle': | |
iconCategory = 14 | |
# Create a row from the input | |
row = { | |
'code': int(code), | |
'hour': int(hour), | |
'iconCategory': int(iconCategory), | |
'latitude': latitude, | |
'longitude': longitude, | |
'magnitudeOfDelay': int(magnitudeOfDelay), | |
'month': int(month) | |
} | |
# Create a df from the row | |
df_row = pd.DataFrame(row, index=[0]) | |
# make the features lower case | |
df_row.columns = df_row.columns.str.lower() | |
df_row.columns = df_row.columns.str.replace(' ', '_') | |
# Get the prediction | |
prediction = model.predict(df_row)[0] | |
return prediction | |
# Gradio UI for `predict`. Components come from the top-level `gr` namespace.
# The legacy `gr.inputs` / `gr.outputs` modules were deprecated in Gradio 3 and
# removed in Gradio 4, so building the interface with them fails at startup.
# The order of `inputs` must match predict's parameter order.
demo = gr.Interface(
    fn=predict,
    title="Stockholm Incident Prediction",
    description="Predicts the duration of a traffic incident in Stockholm",
    allow_flagging="never",
    inputs=[
        gr.Radio(["101", "108", "115", "122", "201", "500", "701", "1101"], label="Code"),
        gr.Radio(["Unknown", "Minor", "Moderate", "Major", "Undefined"], label="Magnitude of Delay"),
        # step=1: predict truncates hour/month to int, so offer only whole values.
        gr.Slider(0, 23, step=1, label="Hour"),
        gr.Radio(
            [
                "Unknown", "Accident", "Fog", "Dangerous Conditions", "Rain", "Ice",
                "Jam", "Lane Closed", "Road Closed", "Road Works", "Wind",
                "Flooding", "Broken Down Vehicle",
            ],
            label="Icon Category",
        ),
        gr.Slider(59.25, 59.40, label="Latitude"),
        gr.Slider(18.00, 18.16, label="Longitude"),
        gr.Slider(1, 12, step=1, label="Month"),
    ],
    outputs=[
        gr.Textbox(label="Duration"),
    ],
)

if __name__ == "__main__":
    demo.launch()