"""Gradio front end for a Stockholm traffic-incident duration model.

At import time the script logs into Hopsworks, downloads version 4 of the
``sthlm_incidents_model`` from the model registry and loads it with joblib.
It then builds a Gradio interface around :func:`predict` and launches it.
"""

import json
import os

import gradio as gr
import hopsworks
import joblib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import requests
from PIL import Image

# Read once at import time; nothing in the visible code uses it.
API_KEY = os.getenv("API-KEY-TOMTOM")

# Radio label -> integer code fed to the model. The insertion order of these
# tables is also the order in which the choices are shown in the UI.
MAGNITUDE_OF_DELAY_CODES = {
    "Unknown": 0,
    "Minor": 1,
    "Moderate": 2,
    "Major": 3,
    "Undefined": 4,
}

# NOTE(review): the codes jump from 11 to 14 on purpose (12 and 13 have no
# label here); presumably this mirrors the upstream incident feed's own
# numbering -- confirm against the data source before "fixing" the gap.
ICON_CATEGORY_CODES = {
    "Unknown": 0,
    "Accident": 1,
    "Fog": 2,
    "Dangerous Conditions": 3,
    "Rain": 4,
    "Ice": 5,
    "Jam": 6,
    "Lane Closed": 7,
    "Road Closed": 8,
    "Road Works": 9,
    "Wind": 10,
    "Flooding": 11,
    "Broken Down Vehicle": 14,
}

# Log into hopsworks
project = hopsworks.login()
fs = project.get_feature_store()
mr = project.get_model_registry()
model = mr.get_model("sthlm_incidents_model", version=4)
model_dir = model.download()
model = joblib.load(os.path.join(model_dir, "sthlm_model.pkl"))
print("Model downloaded")


def _encode_choice(value, codes, field):
    """Translate a radio label into its integer code.

    Args:
        value: The label selected in the UI. Values that are not a known
            label are passed through ``int()`` so that already-encoded
            inputs (e.g. ``3`` or ``"3"``) keep working.
        codes: Mapping from label to integer code.
        field: Human-readable field name used in the error message.

    Returns:
        The integer code for ``value``.

    Raises:
        ValueError: If ``value`` is neither a known label nor convertible to
            an integer (this includes ``None``, i.e. nothing selected).
    """
    if isinstance(value, str) and value in codes:
        return codes[value]
    try:
        return int(value)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"{field}: expected one of {list(codes)} or an integer, "
            f"got {value!r}"
        ) from err


def predict(code, magnitudeOfDelay, hour, iconCategory, latitude, longitude, month):
    """Run the loaded model on a single incident described by the UI inputs.

    Args:
        code: Incident code; converted with ``int()``.
        magnitudeOfDelay: A label from ``MAGNITUDE_OF_DELAY_CODES`` (or an
            integer-like value).
        hour: Hour value; truncated to an integer.
        iconCategory: A label from ``ICON_CATEGORY_CODES`` (or an
            integer-like value).
        latitude: Latitude, passed to the model unchanged.
        longitude: Longitude, passed to the model unchanged.
        month: Month value; truncated to an integer.

    Returns:
        The first element of ``model.predict`` for the one-row frame. The UI
        presents it as the incident duration; its unit is not visible here.

    Raises:
        ValueError: If a categorical input cannot be encoded.
    """
    magnitude_code = _encode_choice(
        magnitudeOfDelay, MAGNITUDE_OF_DELAY_CODES, "magnitudeOfDelay"
    )
    icon_code = _encode_choice(iconCategory, ICON_CATEGORY_CODES, "iconCategory")

    # Feature names are lower-case and listed in this fixed order; presumably
    # this matches the names the model was trained on -- verify against the
    # training pipeline before renaming or reordering.
    row = {
        'code': int(code),
        'hour': int(hour),
        'iconcategory': int(icon_code),
        'latitude': latitude,
        'longitude': longitude,
        'magnitudeofdelay': int(magnitude_code),
        'month': int(month),
    }
    df_row = pd.DataFrame(row, index=[0])

    # Get the prediction
    prediction = model.predict(df_row)[0]
    return prediction


demo = gr.Interface(
    fn=predict,
    title="Stockholm Incident Prediction",
    description="Predicts the duration of a traffic incident in Stockholm",
    allow_flagging="never",
    inputs=[
        gr.inputs.Radio(["101", "108", "115", "122", "201", "500", "701", "1101"]),
        # Choices come from the lookup tables so UI and encoding stay in sync.
        gr.inputs.Radio(list(MAGNITUDE_OF_DELAY_CODES)),
        # step=1: hour and month are truncated to integers in predict(), so
        # the sliders should only offer whole numbers.
        gr.inputs.Slider(0, 23, step=1, label="Hour"),
        gr.inputs.Radio(list(ICON_CATEGORY_CODES), label="Icon Category"),
        gr.inputs.Slider(59.25, 59.40, label="Latitude"),
        gr.inputs.Slider(18.00, 18.16, label="Longitude"),
        gr.inputs.Slider(1, 12, step=1, label="Month"),
    ],
    outputs=[
        gr.outputs.Textbox(label="Duration"),
    ],
)

demo.launch()