JoBeer's picture
Update app.py
61c386c verified
raw
history blame
6.5 kB
#Importierung Pakete
import gradio as gr
import pandas as pd #Verarbeitung excel/csv
from darts.models import LinearRegressionModel #Importierung Klasse LinearRegressionModel von darts
from darts import TimeSeries #In darts wird für Datensätze das TimeSeries Format verwendet
from sklearn.metrics import mean_absolute_error, r2_score
import os
import plotly.express as px #Erstellung von Diagrammen (dynamisch --> bspw. zoom)
# Funktion, um die hochgeladene Datei zu lesen
def upload_excel(file):
    """Read the uploaded Excel file into a DataFrame indexed by its 'time' column.

    Args:
        file: Value delivered by the file-upload component. Either an object
            exposing the file path as ``.name`` or a plain path string;
            ``None`` when nothing has been uploaded.

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` when ``file`` is ``None``.
    """
    if file is None:
        # Nothing uploaded yet: return None explicitly so callers can branch
        # on it instead of tripping over an unbound local variable.
        return None
    # Accept both wrapper objects (path stored in ``.name``) and bare paths.
    path = getattr(file, "name", file)
    return pd.read_excel(path, index_col='time')
# Funktion, um Checkbox-Optionen automatisch zu erstellen
def create_feature_checkbox(df):
    """Build a Gradio update that sets the checkbox choices to the columns of ``df``.

    Args:
        df: Loaded DataFrame, or ``None`` when no data is available.

    Returns:
        ``gr.update(choices=...)`` carrying the column names of ``df``, or an
        empty list when ``df`` is ``None``.
    """
    column_names = [] if df is None else list(df.columns)
    return gr.update(choices=column_names)
# Funktion, um Trainingsstartdatum automatisch zu laden
def fill_start_date(df):
    """Build a Gradio update that pre-fills a textbox with the smallest index value.

    Args:
        df: Loaded DataFrame, or ``None`` when no data is available.

    Returns:
        ``gr.update(value=...)`` holding ``str(df.index.min())``, or an empty
        string when ``df`` is ``None``.
    """
    first_timestamp = '' if df is None else str(df.index.min())
    return gr.update(value=first_timestamp)
# Function that automatically fills in the test end date
def fill_end_date(df):
    """Build a Gradio update that pre-fills a textbox with the largest index value.

    Args:
        df: Loaded DataFrame, or ``None`` when no data is available.

    Returns:
        ``gr.update(value=...)`` holding ``str(df.index.max())``, or an empty
        string when ``df`` is ``None``.
    """
    if df is None:
        # No data loaded: clear the field.
        return gr.update(value='')
    last_timestamp = df.index.max()
    return gr.update(value=str(last_timestamp))
# Function that loads the Excel file and then runs create_feature_checkbox, fill_start_date and fill_end_date
def process_excel(file):
    """Load the uploaded file and derive the widget updates that depend on it.

    Args:
        file: Value delivered by the file-upload component (may be ``None``).

    Returns:
        A 5-tuple: the DataFrame from ``upload_excel``, two checkbox updates
        (one for the target selection, one for the covariate selection), the
        start-date update and the end-date update.
    """
    df = upload_excel(file)
    # The same column list feeds two widgets, so the update is built twice:
    # once for the target selection and once for the covariate selection.
    target_choices = create_feature_checkbox(df)
    covariate_choices = create_feature_checkbox(df)
    start_value = fill_start_date(df)
    end_value = fill_end_date(df)
    return df, target_choices, covariate_choices, start_value, end_value
#Funktion training und training/test-split
def data_split_and_training(df, start_train, end_train, start_test, end_test, target, covariates, lags_target, lags_covariates):
    """Split the data into train/test ranges and fit a linear regression model.

    Args:
        df: DataFrame holding target and covariate columns; sliced by the
            given start/end labels.
        start_train, end_train: Index labels bounding the training slice.
        start_test, end_test: Index labels bounding the test slice.
        target: Column name(s) of the series to forecast.
        covariates: Column name(s) passed to the model as future covariates.
        lags_target: Number of past target values used as model inputs.
        lags_covariates: Number of past covariate values used as model inputs.

    Returns:
        Tuple ``(train, test, model_fit, status)`` where ``train`` and ``test``
        are ``TimeSeries`` objects, ``model_fit`` is the value returned by
        ``model.fit`` and ``status`` is the text ``'Training is complete'``.

    Raises:
        ValueError: If no data has been loaded or a lag count is below 1.
    """
    if df is None:
        raise ValueError("No data loaded: upload an Excel file before training.")

    # range() only accepts integers, but a numeric input widget may hand over
    # floats (e.g. 2.0), so coerce before building the lag lists.
    n_lags_target = int(lags_target)
    n_lags_covariates = int(lags_covariates)
    if n_lags_target < 1 or n_lags_covariates < 1:
        raise ValueError("The number of lagged values must be at least 1.")

    # Split into training and test ranges.
    train_df = df[start_train:end_train]
    test_df = df[start_test:end_test]

    # Convert to darts TimeSeries.
    train = TimeSeries.from_dataframe(train_df)
    test = TimeSeries.from_dataframe(test_df)

    # Turn the lag counts into explicit negative offsets: n -> [-1, ..., -n].
    lags_target_list = list(range(-1, -n_lags_target - 1, -1))
    lags_covariates_list = list(range(-1, -n_lags_covariates - 1, -1))

    # Model instantiation.
    model = LinearRegressionModel(lags=lags_target_list, lags_future_covariates=lags_covariates_list)

    # Model training.
    model_fit = model.fit(train[target], future_covariates=train[covariates])
    return train, test, model_fit, 'Training is complete'
# Funktion zur Speicherung des trainierten Modells
def create_file(state_value):
    """Write the stored model to ``model_fit.pt`` and return that path.

    Previously only the string representation of the state value was written,
    so the downloaded file could not be loaded back as a model. When the value
    provides a callable ``save`` attribute, that method is now used to write
    the file; any other value (including ``None``) still produces the original
    plain-text file.

    Args:
        state_value: Content of the model state. Presumably the fitted darts
            model, whose ``save(path)`` serialises it (TODO confirm against the
            installed darts version).

    Returns:
        The path of the written file, ``"model_fit.pt"``.
    """
    file_path = "model_fit.pt"
    save = getattr(state_value, "save", None)
    if callable(save):
        # Let the object serialise itself so the download is a loadable model.
        save(file_path)
    else:
        # Fallback keeps the original output for values without a save method.
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(f"State Value: {state_value}")
    return file_path
# Funktion für Prediction
def test(model_fit, test, train, covariates):
    """Predict over the test range, plot the result and compute error metrics.

    Args:
        model_fit: Fitted model providing ``predict``.
        test: Test ``TimeSeries``; its length sets the forecast horizon.
        train: Training ``TimeSeries``; appended with ``test`` to supply the
            covariates for the prediction.
        covariates: Column name(s) used as future covariates.

    Returns:
        Tuple ``(metrics, fig)``: a DataFrame with the columns ``target``,
        ``MAE`` and ``R2`` (one row per predicted column) and a Plotly figure
        comparing measured and predicted values of the first predicted column.
    """
    # Prediction over the full test horizon.
    prediction = model_fit.predict(n=len(test), future_covariates=train.append(test)[covariates])

    # Convert both series to pandas once instead of once per use.
    df_prediction = prediction.pd_dataframe()
    df_actual = test.pd_dataframe()

    # Visualisation with Plotly (first predicted column only).
    first_column = df_prediction.columns[0]
    fig = px.line()
    fig.add_scatter(x=df_prediction.index, y=df_actual[first_column], mode='lines', name='test', line_color='red')
    fig.add_scatter(x=df_prediction.index, y=df_prediction[first_column], mode='lines', name='predict', line_color='blue')
    fig.update_yaxes(title_text=first_column)

    # Metrics: scikit-learn expects (y_true, y_pred). R2 is not symmetric, so
    # the measured values must be passed first.
    target_variable = list(df_prediction.columns)
    metrics = pd.DataFrame({
        'target': target_variable,
        'MAE': [mean_absolute_error(df_actual[col], df_prediction[col]) for col in target_variable],
        'R2': [r2_score(df_actual[col], df_prediction[col]) for col in target_variable],
    })
    return metrics, fig
# Gradio-Interface erstellen
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the comments and statement order; confirm in particular whether
# test_button belongs inside the column next to the download button.
with gr.Blocks() as demo:
    gr.Markdown("## ML-based Building Simulation")  # page heading

    # File upload component
    file_input = gr.File(label="Drop an excel file", file_types=[".xls", ".xlsx"])

    # Upload button, wired to process_excel() below
    upload_button = gr.Button("Upload Excel")

    # Feature checkboxes; their choices are filled in after the upload
    checkbox_group_target = gr.CheckboxGroup(label="Select target variable", choices=[], interactive=True)
    checkbox_group_covariate = gr.CheckboxGroup(label="Select exogenous variables", choices=[], interactive=True)

    with gr.Row():
        # Boundaries of the training and test ranges
        start_train = gr.Textbox(label='start_train:', lines=1, interactive=True)
        end_train = gr.Textbox(label='end_train:', lines=1, interactive=True, value = '2019-06-04 02:15:00')
        start_test = gr.Textbox(label='start_test:', lines=1, interactive=True, value = '2019-06-04 02:30:00')
        end_test = gr.Textbox(label='end_test:', lines=1, interactive=True)

    # Lag inputs. precision=0 makes the components deliver integers; the
    # values are fed to range() in data_split_and_training().
    with gr.Row():
        lags_target = gr.Number(value=1, precision=0, label='Number of laged values for target', interactive=True)
        lags_covariates = gr.Number(value=1, precision=0, label='Number of laged values for exogenous variables', interactive=True)

    # State holders so that results of one callback can be used by the next
    df_state = gr.State(None)
    train_state = gr.State(None)
    test_state = gr.State(None)
    model_fit_state = gr.State(None)

    # Training button, wired to data_split_and_training() below
    training_button = gr.Button("Training")

    with gr.Row():
        # Training status
        status_output = gr.Textbox(label="Training status")
        with gr.Column():
            # Download button
            download_button = gr.DownloadButton("Download model")

    test_button = gr.Button("Test")

    # Metrics table output
    metrics_table = gr.DataFrame(label="Metrics", headers=['target','MAE','R2'])

    # Prediction plot output
    plot = gr.Plot()

    # Event wiring
    upload_button.click(
        process_excel,
        inputs=file_input,
        outputs=[df_state, checkbox_group_target, checkbox_group_covariate, start_train, end_test],
    )
    training_button.click(
        data_split_and_training,
        inputs=[df_state, start_train, end_train, start_test, end_test, checkbox_group_target, checkbox_group_covariate, lags_target, lags_covariates],
        outputs=[train_state, test_state, model_fit_state, status_output],
    )
    download_button.click(create_file, inputs=[model_fit_state], outputs=[download_button])
    test_button.click(
        test,
        inputs=[model_fit_state, test_state, train_state, checkbox_group_covariate],
        outputs=[metrics_table, plot],
    )

# Start the demo
demo.launch(show_error=True)