"""Gradio app that predicts aircraft fuel burn with a pre-trained neural network.

The user picks an aircraft model and an origin/destination airport pair. The
aircraft's J/T, CAT and manufacturer values and the route distance are looked
up from CSV files and passed, together with the seat count, to the model.
"""

# Import TensorFlow environment settings. Imported only for its side effects;
# keep it above the load_model import (presumably so the settings take effect
# before TensorFlow is loaded -- confirm against setup_environment.py).
import setup_environment  # noqa: F401

import gradio as gr
import pandas as pd
import numpy as np

# Import the function from load_model.py
from load_model import load_model_and_preprocessor

# Load the pre-trained model and preprocessor
nn_model, nn_preprocessor = load_model_and_preprocessor('nn_model.keras', 'nn_preprocessor.pkl')

# Load the unique aircraft data, keyed by model name.
aircraft_data = pd.read_csv('aircraft_data.csv').drop_duplicates(subset='model')
aircraft_dict = aircraft_data.set_index('model').to_dict(orient='index')

# Load the airport distances. The full frame drives the destination dropdown;
# the dict maps (origin, destination) -> row. Duplicate pairs are dropped
# because to_dict(orient='index') requires a unique index.
airport_data = pd.read_csv('airport_distances.csv')
airport_dict = (
    airport_data
    .drop_duplicates(subset=['Origin_Airport', 'Destination_Airport'])
    .set_index(['Origin_Airport', 'Destination_Airport'])
    .to_dict(orient='index')
)


def predict_fuel_burn(model_name, origin, destination, seats, distance):
    """Predict the fuel burn for one flight.

    Args:
        model_name: Aircraft model; must be a key of ``aircraft_dict``.
        origin: Origin airport code.
        destination: Destination airport code.
        seats: Number of seats; must be > 0 and <= the model's ``seats`` value.
        distance: Route distance; must be > 0.

    Returns:
        The prediction formatted as ``" <value> kg"``, or a user-facing
        validation message when an input is missing or out of range.
    """
    # Gradio passes None for an empty Dropdown/Number, so check for missing
    # inputs before any comparison (``None > int`` raises TypeError).
    if model_name not in aircraft_dict:
        return "Please select an aircraft model."
    if not origin or not destination:
        return "Please select both an origin and a destination airport."
    if seats is None:
        return "Please enter the number of seats."

    # Validate the seat count against the aircraft's capacity.
    aircraft = aircraft_dict[model_name]
    if seats <= 0:
        return "The number of seats must be greater than 0."
    max_seats = aircraft['seats']
    if seats > max_seats:
        return f"The {model_name} aircraft has a maximum of {max_seats} seats."
    if distance is None or distance <= 0:
        return "The distance must be greater than 0."

    # Single-row frame with the columns the preprocessor consumes. 'dist'
    # repeats 'distance'; presumably the preprocessor was fitted with both --
    # confirm before removing either.
    df = pd.DataFrame({
        'model': [model_name],
        'Origin_Airport': [origin],
        'Destination_Airport': [destination],
        'seats': [seats],
        'distance': [distance],
        'J/T': [aircraft['J/T']],
        'CAT': [aircraft['CAT']],
        '_Manufacturer': [aircraft['_Manufacturer']],
        'dist': [distance],
    })

    # The first [0] selects the only row; the second [0] selects its first
    # output value.
    fuel_burn_prediction_nn = nn_model.predict(nn_preprocessor.transform(df))[0]
    return f" {fuel_burn_prediction_nn[0]:.2f} kg"


def _aircraft_attributes(model_name):
    """Return (J/T, CAT, manufacturer) for *model_name*, or blanks if unknown."""
    aircraft = aircraft_dict.get(model_name)
    if aircraft is None:
        return "", "", ""
    return aircraft['J/T'], aircraft['CAT'], aircraft['_Manufacturer']


def update_fields(model_name):
    """Fill the read-only J/T, CAT and manufacturer boxes for the selected model.

    Returns a dict keyed by the module-level ``jt``, ``cat`` and
    ``manufacturer`` components created in the Blocks layout below.
    """
    jt_value, cat_value, manufacturer_value = _aircraft_attributes(model_name)
    return {
        jt: gr.update(value=jt_value),
        cat: gr.update(value=cat_value),
        manufacturer: gr.update(value=manufacturer_value),
    }


def update_destination_options(origin):
    """Offer only the destinations reachable from *origin*, sorted.

    The current destination is cleared. Otherwise a destination left over from
    the previous origin, and its distance, would stay selected.
    """
    destinations = airport_data.loc[
        airport_data['Origin_Airport'] == origin, 'Destination_Airport'
    ].unique()
    return gr.update(choices=sorted(destinations), value=None)


def update_distance(origin, destination):
    """Show the distance for the (origin, destination) pair, or 0 if unknown/missing."""
    distance_value = airport_dict.get((origin, destination), {}).get('distance')
    if distance_value is None or pd.isna(distance_value):
        return gr.update(value=0)  # 0 makes predict_fuel_burn reject the route
    return gr.update(value=distance_value)


with gr.Blocks() as demo:
    gr.Markdown("## Fuel Burn Prediction")

    model_choices = list(aircraft_dict)
    default_model = model_choices[0] if model_choices else None
    # Pre-fill the derived fields for the preselected model; the change event
    # below only fires when the user picks a different model.
    default_jt, default_cat, default_manufacturer = _aircraft_attributes(default_model)

    with gr.Row():
        model_name = gr.Dropdown(
            label="Aircraft Model",
            choices=model_choices,
            value=default_model,
        )
        origin = gr.Dropdown(
            label="Origin Airport",
            choices=sorted(airport_data['Origin_Airport'].unique()),
        )
        destination = gr.Dropdown(
            label="Destination Airport",
            choices=[],
        )

    with gr.Row():
        jt = gr.Textbox(label="J/T", value=default_jt, interactive=False)
        cat = gr.Textbox(label="CAT", value=default_cat, interactive=False)
        manufacturer = gr.Textbox(label="Manufacturer", value=default_manufacturer, interactive=False)

    # precision=0 makes the seat count an integer.
    seats = gr.Number(label="Seats", precision=0)
    distance = gr.Number(label="Distance", interactive=False)

    model_name.change(fn=update_fields, inputs=model_name, outputs=[jt, cat, manufacturer])
    origin.change(fn=update_destination_options, inputs=origin, outputs=destination)
    destination.change(fn=update_distance, inputs=[origin, destination], outputs=distance)

    submit_btn = gr.Button("Predict Fuel Burn")
    result = gr.Textbox(label="Fuel Burn in Kg", interactive=False)

    submit_btn.click(
        predict_fuel_burn,
        inputs=[model_name, origin, destination, seats, distance],
        outputs=result,
    )

if __name__ == "__main__":
    demo.launch()