Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from models.neural_network.inference import load_model_and_preprocessor
|
6 |
+
|
7 |
+
|
8 |
+
# Load the pre-trained model and preprocessor
|
9 |
+
nn_model, nn_preprocessor = load_model_and_preprocessor('nn_model.keras', 'nn_preprocessor.pkl')
|
10 |
+
|
11 |
+
# Load the unique aircraft data and airport distances
|
12 |
+
aircraft_data = pd.read_csv('aircraft_data.csv').drop_duplicates(subset='model')
|
13 |
+
airport_data = pd.read_csv('airport_distances.csv')
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
def predict_fuel_burn(model_name, origin, destination, seats, distance):
|
18 |
+
# Validate the distance against seats
|
19 |
+
max_seats = aircraft_dict[model_name]['seats']
|
20 |
+
if seats > max_seats:
|
21 |
+
return f"The {model_name} aircraft has a maximum of {max_seats} seats."
|
22 |
+
if seats <= 0:
|
23 |
+
return "The number of seats must be greater than 0."
|
24 |
+
if distance <= 0:
|
25 |
+
return "The distance must be greater than 0."
|
26 |
+
|
27 |
+
# Prepare the input data for the model
|
28 |
+
data = {
|
29 |
+
'model': [model_name],
|
30 |
+
'Origin_Airport': [origin],
|
31 |
+
'Destination_Airport': [destination],
|
32 |
+
'seats': [seats],
|
33 |
+
'distance': [distance],
|
34 |
+
'J/T': [aircraft_dict[model_name]['J/T']],
|
35 |
+
'CAT': [aircraft_dict[model_name]['CAT']],
|
36 |
+
'_Manufacturer': [aircraft_dict[model_name]['_Manufacturer']],
|
37 |
+
'dist': [distance]
|
38 |
+
}
|
39 |
+
|
40 |
+
df = pd.DataFrame(data)
|
41 |
+
|
42 |
+
# Make the prediction
|
43 |
+
fuel_burn_prediction_nn = nn_model.predict(nn_preprocessor.transform(df))[0]
|
44 |
+
|
45 |
+
return f" {fuel_burn_prediction_nn[0]:.2f} kg"
|
46 |
+
|
47 |
+
|
48 |
+
def update_fields(model_name):
|
49 |
+
return {
|
50 |
+
jt: gr.update(value=aircraft_dict[model_name]['J/T']),
|
51 |
+
cat: gr.update(value=aircraft_dict[model_name]['CAT']),
|
52 |
+
manufacturer: gr.update(value=aircraft_dict[model_name]['_Manufacturer'])
|
53 |
+
}
|
54 |
+
|
55 |
+
|
56 |
+
def update_destination_options(origin):
|
57 |
+
destinations = airport_data[airport_data['Origin_Airport'] == origin]['Destination_Airport'].unique()
|
58 |
+
return gr.update(choices=list(destinations))
|
59 |
+
|
60 |
+
|
61 |
+
def update_distance(origin, destination):
|
62 |
+
distance_value = airport_dict.get((origin, destination), {}).get('distance', 'Distance not found')
|
63 |
+
if distance_value == 'Distance not found':
|
64 |
+
return gr.update(value=0) # Return 0 if distance is not found
|
65 |
+
return gr.update(value=distance_value)
|
66 |
+
|
67 |
+
|
68 |
+
with gr.Blocks() as demo:
|
69 |
+
gr.Markdown("## Fuel Burn Prediction")
|
70 |
+
|
71 |
+
with gr.Row():
|
72 |
+
model_name = gr.Dropdown(
|
73 |
+
label="Aircraft Model",
|
74 |
+
choices=list(aircraft_dict.keys()),
|
75 |
+
value=list(aircraft_dict.keys())[0],
|
76 |
+
)
|
77 |
+
origin = gr.Dropdown(
|
78 |
+
label="Origin Airport",
|
79 |
+
choices=sorted(airport_data['Origin_Airport'].unique())
|
80 |
+
)
|
81 |
+
destination = gr.Dropdown(
|
82 |
+
label="Destination Airport",
|
83 |
+
choices=[]
|
84 |
+
)
|
85 |
+
|
86 |
+
with gr.Row():
|
87 |
+
jt = gr.Textbox(label="J/T", interactive=False)
|
88 |
+
cat = gr.Textbox(label="CAT", interactive=False)
|
89 |
+
manufacturer = gr.Textbox(label="Manufacturer", interactive=False)
|
90 |
+
seats = gr.Number(label="Seats")
|
91 |
+
|
92 |
+
distance = gr.Number(label="Distance", interactive=False)
|
93 |
+
|
94 |
+
model_name.change(fn=update_fields, inputs=model_name, outputs=[jt, cat, manufacturer])
|
95 |
+
origin.change(fn=update_destination_options, inputs=origin, outputs=destination)
|
96 |
+
destination.change(fn=update_distance, inputs=[origin, destination], outputs=distance)
|
97 |
+
|
98 |
+
submit_btn = gr.Button("Predict Fuel Burn")
|
99 |
+
result = gr.Textbox(label="Fuel Burn in Kg", interactive=False)
|
100 |
+
|
101 |
+
submit_btn.click(predict_fuel_burn, inputs=[model_name, origin, destination, seats, distance], outputs=result)
|
102 |
+
|
103 |
+
demo.launch()
|
104 |
+
|