import os |
import re |
import uuid |
import time |
import shutil |
import zipfile |
import threading |
import subprocess |
import select |
from datetime import datetime |
from concurrent.futures import ThreadPoolExecutor |
import dash |
from dash import dcc, html |
from dash.dependencies import Input, Output, State, ALL |
import dash_bootstrap_components as dbc |
from dash.exceptions import PreventUpdate |
from flask import Flask, render_template, request, send_file, jsonify |
import yaml |
import ruamel.yaml |
import pandas as pd |
import logging |
# Root logging at DEBUG level; `logger` is this module's named logger.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Flask application serving the welcome page; the Dash app below is mounted on it.
server = Flask(__name__)
# Random key generated at import time, so it changes every time the process starts.
server.secret_key = os.urandom(24)
@server.route('/')
def welcome_page():
    """
    Handles the welcome page route.

    The user name is taken from the part of the request host that precedes
    "-stm32" (falling back to "modelzoo_user" when the host does not match).
    Duplicate mode is enabled only for the "stmicroelectronics" user, and the
    flag is passed to the 'index.html' template.

    Returns:
        str: The rendered 'index.html' template with the duplicate_mode parameter.
    """
    host = request.host
    # Use the configured module logger instead of print() so the diagnostics
    # follow the logging configuration set up at import time.
    logger.debug("host: %s", host)
    usr_match = re.match(r'^(.*?)\-stm32', host)
    logger.debug("usr_match: %s", usr_match)
    hf_user = usr_match.group(1) if usr_match else "modelzoo_user"
    duplicate_mode = hf_user == "stmicroelectronics"
    logger.debug("hf_user: %s", hf_user)
    logger.debug("duplicate_mode: %s", duplicate_mode)
    return render_template('index.html', duplicate_mode=duplicate_mode)
# Bootstrap theme applied to the Dash pages.
external_stylesheets = [dbc.themes.LITERA]
# Dash app mounted on the Flask `server` under /dash_app/. Callback exceptions are
# suppressed because some callback targets (e.g. 'apply-button', 'submission-outcome')
# only exist after create_yaml() has rendered the YAML form.
app = dash.Dash(__name__, server=server,external_stylesheets=external_stylesheets, url_base_pathname='/dash_app/', suppress_callback_exceptions=True)
# Use-case key (dropdown value) -> user_config.yaml path. The paths are relative,
# so they resolve against the process's current working directory.
local_yamls = {
    'image_classification': 'stm32ai-modelzoo-services/image_classification/src/user_config.yaml',
    'human_activity_recognition': 'stm32ai-modelzoo-services/human_activity_recognition/src/user_config.yaml',
    'hand_posture': 'stm32ai-modelzoo-services/hand_posture/src/user_config.yaml',
    'object_detection': 'stm32ai-modelzoo-services/object_detection/src/user_config.yaml',
    'audio_event_detection': 'stm32ai-modelzoo-services/audio_event_detection/src/user_config.yaml',
    'pose_estimation': 'stm32ai-modelzoo-services/pose_estimation/src/user_config.yaml',
    'semantic_segmentation': 'stm32ai-modelzoo-services/semantic_segmentation/src/user_config.yaml'
}
def banner():
    """
    Build the fixed top bar shown on every dashboard page.

    The bar holds, from left to right: a GitHub link to the
    stm32ai-modelzoo-services repository, the ST logo, and a link to the
    ST Edge AI Developer Cloud.

    Returns:
        html.Div: The banner component (id "banner").
    """
    return html.Div(
        id="banner",
        className="top-bar",
        style={
            "display": "flex",
            "align-items": "center",
            "justify-content": "space-between",
            "position": "fixed",
            "top": "0",
            "left": "0",
            "width": "100%",
            "z-index": "1000",
            # Was "#3234b", an invalid 5-digit hex that browsers ignore; use the
            # navy used for the rest of the UI so the white text stays readable.
            "background-color": "#03234b",
            "padding": "10px 20px",
            "box-shadow": "0px 2px 4px rgba(0, 0, 0, 0.1)"
        },
        children=[
            html.A(
                id="learn-more-button",
                children=[
                    html.Img(
                        src=app.get_asset_url("github-mark-white.png"),
                        style={"width": "20px", "height": "20px", "margin-right": "10px"}
                    ),
                    "stm32ai-modelzoo",
                ],
                href="https://github.com/STMicroelectronics/stm32ai-modelzoo-services",
                target="_blank",
                style={
                    "display": "flex",
                    "align-items": "center",
                    "color": "#ffffff",
                    "text-decoration": "none",
                    "font-size": "15px",
                    "font-family": "Arial, sans-serif"
                }
            ),
            html.Div(
                html.Img(
                    id="logo",
                    src=app.get_asset_url("ST_logo_2024_white.png"),
                    style={"width": "50px", "height": "auto"}
                ),
                style={"text-align": "center"}
            ),
            html.Div(
                [
                    html.A(
                        [
                            html.H5(
                                "ST Edge AI Developer Cloud",
                                style={
                                    "margin": "0",
                                    "text-align": "right",
                                    "color": "#ffffff",
                                    "font-size": "15px",
                                    "font-weight": "bold",
                                    "font-family": "Arial, sans-serif"
                                }
                            )
                        ],
                        href="https://stm32ai-cs.st.com/home",
                        target="_blank",
                        style={
                            "display": "flex",
                            "align-items": "center",
                            "text-decoration": "none"
                        }
                    )
                ],
                style={"padding-right": "10px"}
            )
        ]
    )
def read_configs(selected_model):
    """
    Loads the user_config YAML file of the selected use case.

    Args:
        selected_model (str): Key of ``local_yamls`` selecting the YAML file path.
    Returns:
        The data parsed by ``yaml.safe_load`` (a dict for a regular config file;
        ``None`` for an empty file).
    Raises:
        ValueError: If no model is selected, the model is unknown, or the file
            cannot be read or parsed. The original error is chained as ``__cause__``.
    """
    if not selected_model:
        raise ValueError("No model selected. Please select a valid model.")
    if selected_model not in local_yamls:
        raise ValueError(f"Model '{selected_model}' not found in local_yamls")
    yaml_path = local_yamls[selected_model]
    try:
        # YAML files are UTF-8; don't depend on the server's locale encoding.
        with open(yaml_path, 'r', encoding='utf-8') as file:
            return yaml.safe_load(file)
    except Exception as e:
        # Chain the cause so the real I/O or parse error stays in the traceback.
        raise ValueError(f"Error reading YAML file at {yaml_path}: {e}") from e
def build_yaml_form(yaml_content, parent_key=''):
    """
    Recursively builds accordion items mirroring a YAML mapping.

    Nested mappings become nested accordion items. Leaf values become input
    fields whose pattern-matching id is
    ``{'type': 'yaml-setting', 'index': <dotted key path>}``:
    a Checklist for booleans, a multi-select Dropdown for lists, and a text
    Input for everything else.

    Parameters:
    - yaml_content (dict): The YAML content to build the form from.
    - parent_key (str): Dotted path of the enclosing mapping. Default is an
      empty string (top level).
    Returns:
    - list: dbc.AccordionItem components representing the form fields.
    """
    accordion_items = []
    for key, value in yaml_content.items():
        # YAML keys are not always strings (e.g. integer or boolean keys). The
        # title and the dotted id path must be strings, because the path is later
        # split on '.' to locate the setting in the YAML file.
        key_text = str(key)
        full_key = f"{parent_key}.{key_text}" if parent_key else key_text
        title = key_text.capitalize()
        if isinstance(value, dict):
            nested_accordion = build_yaml_form(value, full_key)
            accordion_items.append(
                dbc.AccordionItem(
                    nested_accordion,
                    title=title
                )
            )
            continue
        field = [html.Label(key_text, style={"font-weight": "bold", "margin-bottom": "5px"})]
        field_style = {"padding": "10px", "border": "1px solid #ddd", "margin-bottom": "10px"}
        field_id = {'type': 'yaml-setting', 'index': full_key}
        # bool is tested before list/other types: a ticked Checklist submits [True].
        if isinstance(value, bool):
            field.append(
                dcc.Checklist(
                    id=field_id,
                    options=[{'label': '', 'value': True}],
                    value=[True] if value else [],
                    style=field_style
                )
            )
        elif isinstance(value, list):
            field.append(
                dcc.Dropdown(
                    id=field_id,
                    options=[{'label': str(v), 'value': v} for v in value],
                    value=value,
                    multi=True,
                    style=field_style
                )
            )
        else:
            field.append(
                dcc.Input(
                    id=field_id,
                    value=value,
                    type='text',
                    style=field_style
                )
            )
        accordion_items.append(
            dbc.AccordionItem(
                field,
                title=title
            )
        )
    return accordion_items
def create_yaml(yaml_content):
    """
    Wrap the YAML settings accordion in a submittable form.

    Parameters:
        yaml_content (dict): Parsed YAML data passed to build_yaml_form().
    Returns:
        dbc.Form: The settings accordion (collapsed initially), a centred
        'Submit' button (id 'apply-button') and an empty status area
        (id 'submission-outcome') that the submit callback fills in.
    """
    settings = dbc.Accordion(build_yaml_form(yaml_content), start_collapsed=True)

    submit_button_style = {
        'background-color': '#FFD200',
        'color': '#03234b',
        'font-size': '14px',
        'padding': '10px 10px 10px 10px',
        'border-radius': '5px',
        'margin-top': '15px',
        'border': '2px solid #FFD200',
        'box-shadow': '0px 4px 6px rgba(0, 0, 0, 0.1)',
    }
    submit_row_style = {
        'display': 'flex',
        'justify-content': 'center',
        'margin-top': '15px',
    }
    outcome_style = {
        'marginTop': '10px',
        'textAlign': 'center',
        'fontStyle': 'italic',
        'color': '#03234b',
        'font-size': '14px'
    }

    submit_row = html.Div(
        dbc.Button('Submit', id='apply-button', style=submit_button_style),
        style=submit_row_style,
    )
    outcome = html.Div(id='submission-outcome', style=outcome_style)
    return dbc.Form([settings, submit_row, outcome])
def process_form_configs(form_configs):
    """
    Normalizes raw form values before they are merged into a YAML file.

    - ``None`` values (untouched empty inputs) are dropped.
    - Single-element lists are unwrapped, which turns a ticked boolean
      checklist (``[True]``) into ``True``. This also unwraps a multi-select
      with exactly one chosen item, so callers needing list semantics must
      handle that themselves.
    - Strings that parse as numbers become ``int`` when possible, otherwise
      ``float``. Scientific notation such as ``"1e-3"`` is accepted.
      ``"nan"``, ``"inf"`` and any non-numeric text are kept unchanged.

    Args:
        form_configs (dict): Mapping of dotted YAML key -> submitted form value.
    Returns:
        dict: The same keys (minus ``None`` values) mapped to normalized values.
    """
    import math

    def _as_number(text):
        """Return *text* as an int/float if it is a finite number, else unchanged."""
        stripped = text.strip()
        try:
            return int(stripped)
        except ValueError:
            pass
        try:
            number = float(stripped)
        except ValueError:
            return text
        # float() also accepts "nan"/"inf"/"infinity"; those stay textual.
        return number if math.isfinite(number) else text

    updated_yaml = {}
    for key, value in form_configs.items():
        if value is None:
            continue
        if isinstance(value, list) and len(value) == 1:
            value = value[0]
        if isinstance(value, str):
            value = _as_number(value)
        updated_yaml[key] = value
    return updated_yaml
def create_archive(archive_path, directory_to_compress):
    """
    Creates a ZIP archive of every file below a directory.

    Files are stored with paths relative to ``directory_to_compress``, in a
    deterministic (sorted) order. The archive itself is skipped when it lives
    inside the directory being compressed.

    Parameters:
        archive_path (str): Path of the ZIP archive to create (overwritten if present).
        directory_to_compress (str): Directory whose contents are compressed.
    Returns:
        None
    """
    archive_abspath = os.path.abspath(archive_path)
    with zipfile.ZipFile(archive_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        # Files are written sequentially on purpose. A ZipFile rejects a second
        # writer while one entry is being written (ValueError) and is not safe
        # for concurrent writers. The former thread-pool version therefore lost
        # files, and its errors were swallowed inside never-inspected futures.
        for root_dir, sub_dirs, files in os.walk(directory_to_compress):
            sub_dirs.sort()  # in-place so os.walk visits sub-directories in order
            for file_name in sorted(files):
                file_path = os.path.join(root_dir, file_name)
                if os.path.abspath(file_path) == archive_abspath:
                    continue
                arcname = os.path.relpath(file_path, directory_to_compress)
                zipf.write(file_path, arcname=arcname)
def create_dashboard_layout():
    """
    Creates the layout for the application: STM32ModelZoo dashboard.

    The page contains:
    - the banner and the use-case dropdown;
    - the YAML configuration container;
    - the ST Edge AI Developer Cloud credential inputs and the
      "Launch training" button;
    - the command-output card and the two metric graphs;
    - a 1-second polling interval and the "Download outputs" button.

    Most sections start hidden and are revealed by the callbacks.

    Returns:
        html.Div: Root component holding the banner and a fluid dbc.Container.
    """
    return html.Div([
        banner(),
        dbc.Container([
            dcc.Location(id='url', refresh=False),
            dbc.Row(dbc.Col(html.H3("STM32 Modelzoo", style={'color': '#03234b', 'text-align': 'center', "margin-top": "80px", "font-family": "Arial, sans-serif"}), className="mb-4")),
            dbc.Row([
                dbc.Col(
                    html.H5("Use case selection", style={'color': '#03234b', 'margin-bottom': '10px'}),
                    width=12
                )
            ], id="use-case-section", style={"display": "none"}),
            dbc.Row(dbc.Col(dcc.Dropdown(
                id='selected-model',
                options=[
                    {'label': 'Image Classification (IC)', 'value': 'image_classification'},
                    {'label': 'Human Activity Recognition (HAR)', 'value': 'human_activity_recognition'},
                    {'label': 'Hand Posture', 'value': 'hand_posture'},
                    {'label': 'Audio Event Detection(AED)', 'value': 'audio_event_detection'},
                    {'label': 'Object Detection', 'value': 'object_detection'},
                    {'label': 'Pose estimation', 'value': 'pose_estimation'},
                    {'label': 'Semantic Segmentation', 'value': 'semantic_segmentation'},
                ],
                placeholder="Please select your use case",
                className="mb-4"
            ))),
            dbc.Row(
                dbc.Col(
                    html.Div(
                        id='toggle-yaml',
                        children=[
                            html.P([
                                "Please update the YAML file: Dataset path (example: ../datasets/your_use_case/name_of_dataset) or datasets/your_prepared_dataset. For more details, refer to the ",
                                html.A("README", href="https://huggingface.co/spaces/STMicroelectronics/stm32-modelzoo-app/blob/main/datasets/README.md", target="_blank", style={'color': '#007bff', 'text-decoration': 'underline'}),
                                "."
                            ], style={'font-family': 'Arial, sans-serif', 'color': '#03234b', 'fontSize': '15px'}),
                            dcc.RadioItems(
                                id='modify-yaml-choice',
                                labelStyle={'display': 'inline-block', 'margin-right': '10px'},
                                className="mb-4",
                            ),
                            dcc.Upload(
                                id='load-yaml-file',
                                children=html.Button('Upload YAML File'),
                                style={'display': 'none'}
                            ),
                            html.Div(id='load-state', style={'margin-top': '10px'}),
                            # Filled with the create_yaml() form by display_yaml_form.
                            html.Div(id='yaml-layout', style={'display': 'none'})
                        ],
                        style={'font-family': 'Arial, sans-serif', 'display': 'none'}
                    )
                )
            ),
            dbc.Row([
                dbc.Col([
                    # Colour was '03234b' (missing '#'), which is not a valid CSS colour.
                    html.P("Enter your ST Edge AI Developer Cloud credentials:", style={'color': '#03234b', 'fontSize': '15px', 'fontWeight': 'bold'}, className="credentials-text"),
                    dcc.Input(id='devcloud-username-input', type='text', placeholder='Enter username', className="input-field mb-2"),
                    dcc.Input(id='devcloud-password-input', type='password', placeholder='Enter password', className="input-field mb-4")
                ], width=6),
                dbc.Col([
                    # Colour was "#3234b", an invalid 5-digit hex; use the UI navy.
                    dbc.Button('Launch training', id='process-button', color="#03234b", className="start-button mb-4", style={'display': 'none', 'box-shadow': '0px 4px 6px rgba(0, 0, 0, 0.1)'})
                ], className="credentials-col")
            ], id='credentials-section', style={
                'display': 'none',
                'justify-content': 'center',
                'align-items': 'center',
                'height': '100vh',
            }, className="credentials-section mb-4"),
            dbc.Row([
                dbc.Col(
                    html.H5("Results visualization", style={'color': '#03234b', 'margin-bottom': '10px'}),
                    width=12
                )
            ], id="results-section", style={"display": "none"}),
            dbc.Row([
                dbc.Col(dbc.Card([
                    dbc.CardHeader("Command output"),
                    dbc.CardBody(
                        html.Div(id='log-reader', style={'whiteSpace': 'pre-wrap', 'padding-top': '15px', 'height': '100%', 'overflow': 'auto'}),
                        style={'height': '300px'}
                    )
                ]))
            ], style={'margin-bottom': '30px'}),
            dbc.Row([
                dbc.Col(dbc.Card([
                    dbc.CardHeader("Metrics", style={'background-color': '#03234b', 'color': 'white'}),
                    dbc.CardBody(
                        dcc.Graph(id='acc-visualization', style={'height': '100%', 'width': '100%'}),
                        style={'height': '400px', 'display': 'flex', 'justify-content': 'center', 'align-items': 'center'}
                    )
                ]), width=6, style={'padding': '10px'}),
                dbc.Col(dbc.Card([
                    dbc.CardHeader("Metrics", style={'background-color': '#03234b', 'color': 'white'}),
                    dbc.CardBody(
                        dcc.Graph(id='loss-visualization', style={'height': '100%', 'width': '100%'}),
                        style={'height': '400px', 'display': 'flex', 'justify-content': 'center', 'align-items': 'center'}
                    )
                ]), width=6, style={'padding': '10px'})
            ], style={'margin-bottom': '30px'}),
            # 1000 ms tick polled by refresh_metrics and toggle_download_button.
            dcc.Interval(id='interval-widget', interval=1000, n_intervals=0),
            dcc.Download(id="download-resource"),
            dbc.Row(
                dbc.Col(
                    dbc.Button('Download outputs', id='download-action', className="mb-4", style={
                        'background-color': '#ffd200',
                        'color': '#ffffff',
                        'font-size': '14px',
                        'padding': '10px 10px',
                        'border-radius': '5px',
                        'box-shadow': '0px 4px 6px rgba(0, 0, 0, 0.1)',
                        'margin-top': '20px'
                    }),
                    style={
                        'display': 'flex',
                        'justify-content': 'center',
                        'alignItems': 'center',
                    }
                )
            )
        ], fluid=True)
    ])
# The function itself (not its result) is assigned, so Dash calls it to build the
# layout when a page is served.
app.layout = create_dashboard_layout
# Shared buffer of output lines: reset and filled by run_script (via fill_logs)
# in a background thread, and rendered by refresh_metrics.
logs = []
# Protects `logs` in fill_logs and run_script.
lock = threading.Lock()
# Becomes True once refresh_metrics launches a run; gates metric polling there.
new_training = False
def fill_logs(message):
    """
    Append one entry to the shared ``logs`` buffer while holding ``lock``.

    Parameters:
        message (str): Text to record (typically one line of script output).
    Returns:
        None
    """
    lock.acquire()
    try:
        logs.append(message)
    finally:
        lock.release()
def run_script(script, devcloud_username, devcloud_password):
    """
    Executes a Python script in a child process and streams its output into ``logs``.

    The shared log buffer is cleared first. The ST Developer Cloud credentials
    (``stmai_username`` / ``stmai_password``) and ``STATS_TYPE`` are passed to
    the child through its own environment. They are not written into this
    server's ``os.environ``, so they are not shared with other requests and are
    not left behind after the run. stdout and stderr are merged, so lines are
    logged in the order the child emitted them.

    Parameters:
    - script (str): The path to the script to be executed.
    - devcloud_username (str): Username for ST Developer Cloud.
    - devcloud_password (str): Password for ST Developer Cloud.
    Returns:
    - None. Output and failures are reported through fill_logs().
    """
    global logs
    with lock:
        logs = []
    if not (devcloud_username and devcloud_password):
        # Environment values must be strings. A missing credential used to kill
        # this thread with a TypeError and show nothing in the dashboard.
        fill_logs("Please enter both ST Developer Cloud username and password:")
        return
    child_env = dict(os.environ)
    child_env['stmai_username'] = devcloud_username
    child_env['stmai_password'] = devcloud_password
    child_env['STATS_TYPE'] = 'HuggingFace_devcloud'
    # The child's stdout is a pipe, which Python block-buffers. Unbuffered output
    # lets the polling dashboard display progress while training runs.
    child_env['PYTHONUNBUFFERED'] = '1'
    try:
        # A single merged pipe read line by line replaces the old select() loop.
        # That loop could spin on a stream at EOF, dropped stderr lines still
        # unread when stdout closed, and (select on pipes) is POSIX-only.
        with subprocess.Popen(
            ['python3', script],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors='replace',
            bufsize=1,
            env=child_env,
        ) as execution:
            for line in execution.stdout:
                fill_logs(line)
        # Leaving the `with` block waits for the child process to exit.
    except OSError as exc:
        fill_logs(f"Failed to run {script}: {exc}")
def execute_async(script, devcloud_username, devcloud_password):
    """
    Start run_script() on a new background thread and return immediately.

    Parameters:
        script (str): Path of the Python script to run.
        devcloud_username (str): ST Developer Cloud username forwarded to run_script.
        devcloud_password (str): ST Developer Cloud password forwarded to run_script.
    Returns:
        None
    """
    call_args = (script, devcloud_username, devcloud_password)
    threading.Thread(target=run_script, args=call_args).start()
@app.callback(
    Output("config-section", "style"),
    Input('selected-model', 'value')
)
def toggle_config_section(selected_model):
    """
    Show the "config-section" element once a use case is chosen.

    NOTE(review): no component with id "config-section" is created by
    create_dashboard_layout() or create_yaml(), so this callback appears to
    have no target. Confirm whether it is dead code or should point at an
    existing section.

    Parameters:
        selected_model (str): Current value of the use-case dropdown.
    Returns:
        dict: ``{"display": "block"}`` when a model is selected, else ``{"display": "none"}``.
    """
    visibility = "block" if selected_model else "none"
    return {"display": visibility}
@app.callback(
    Output('toggle-yaml', 'style'),
    Input('selected-model', 'value')
)
def dipslay_yaml_container(selected_model):
    """
    Reveal the YAML-update container ('toggle-yaml') once a use case is picked.

    Args:
        selected_model (str): Current value of the use-case dropdown.
    Returns:
        dict: ``{'display': 'block'}`` if a model is selected, otherwise ``{'display': 'none'}``.
    """
    if not selected_model:
        return {'display': 'none'}
    return {'display': 'block'}
@app.callback(
    [Output('yaml-layout', 'style'),
     Output('yaml-layout', 'children')],
    [Input('modify-yaml-choice', 'value'),
     Input('selected-model', 'value')]
)
def display_yaml_form(selection_update, selected_model):
    """
    Render the editable YAML form for the chosen use case.

    Args:
        selection_update (str): Value of the 'modify-yaml-choice' radio items (not used).
        selected_model (str): Use case chosen in the dropdown.
    Returns:
        tuple: ``(style, children)`` for 'yaml-layout'. On success the container
        is shown with the create_yaml() form. Otherwise it is hidden and the
        children hold a message: a prompt when nothing is selected, or the
        error text when loading/building fails.
    """
    hidden = {'display': 'none'}
    if selected_model:
        try:
            form = create_yaml(read_configs(selected_model))
        except ValueError as e:
            return hidden, f"Error: {str(e)}"
        except Exception as e:
            return hidden, f"Unexpected Error: {str(e)}"
        return {'display': 'block'}, form
    return hidden, "Please select a model to display its configuration."
@app.callback(
    Output('credentials-section', 'style'),
    [Input('modify-yaml-choice', 'value'),
     Input('selected-model', 'value'),
     Input('apply-button', 'n_clicks')]
)
def display_credentials(selection_update, selected_model, n_clicks):
    """
    Show the credentials section once the YAML form has been submitted.

    The first two inputs only re-trigger the callback; visibility depends
    solely on the Submit button's click count.

    Args:
        selection_update (str): Value of the 'modify-yaml-choice' radio items (unused).
        selected_model (str): Use case chosen in the dropdown (unused).
        n_clicks (int): Click count of the 'apply-button' (Submit) button; None before any click.
    Returns:
        dict: ``{'display': 'block'}`` after at least one click, otherwise ``{'display': 'none'}``.
    """
    submitted = bool(n_clicks)
    return {'display': 'block' if submitted else 'none'}
@app.callback(
    Output('process-button', 'style'),
    [Input('apply-button', 'n_clicks')]
)
def display_launch_training(n_clicks):
    """
    Reveal the "Launch training" button after the YAML form was submitted.

    Parameters:
        n_clicks (int): Click count of the 'apply-button'; None before any click.
    Returns:
        dict: ``{'display': 'inline-block'}`` after a positive number of clicks,
        otherwise ``{'display': 'none'}``.
    """
    if not n_clicks or n_clicks <= 0:
        return {'display': 'none'}
    return {'display': 'inline-block'}
@app.callback(
    Output("results-section", "style"),
    Input('process-button', 'n_clicks')
)
def display_results_section(n_clicks):
    """
    Reveal the "Results visualization" heading once training was launched.

    Parameters:
        n_clicks (int): Click count of the 'process-button'; None before any click.
    Returns:
        dict: ``{"display": "block"}`` after a positive number of clicks, otherwise ``{"display": "none"}``.
    """
    launched = bool(n_clicks) and n_clicks > 0
    return {"display": "block" if launched else "none"}
# (train, validation) column pairs, or a lone validation column, plotted when present
# in train_metrics.csv. The first two matching groups feed the two graphs.
_REFRESH_METRIC_GROUPS = (
    ('accuracy', 'val_accuracy'),
    ('loss', 'val_loss'),
    ('oks', 'val_oks'),
    ('val_map',),
)


def _latest_dated_run(outputs_folder):
    """
    Return the name of the newest run directory in ``outputs_folder``, or None.

    Run directories start with '20' and are named '%Y_%m_%d_%H_%M_%S'. Entries
    that start with '20' but do not parse are skipped. Previously they raised
    ValueError on every polling tick.
    """
    runs = []
    for name in os.listdir(outputs_folder):
        if not name.startswith('20') or not os.path.isdir(os.path.join(outputs_folder, name)):
            continue
        try:
            runs.append((datetime.strptime(name, '%Y_%m_%d_%H_%M_%S'), name))
        except ValueError:
            continue
    return max(runs)[1] if runs else None


def _metric_figure(metrics_dataframe, columns):
    """
    Build a Plotly figure dict plotting ``columns`` against the 'epoch' column.

    The last column (validation) is drawn in #3CB4E6 and any preceding one
    (training) in #FFD200. The y-axis is titled after the first column.
    """
    last_index = len(columns) - 1
    traces = [
        {
            'x': metrics_dataframe['epoch'],
            'y': metrics_dataframe[column],
            'type': 'line',
            'name': column.capitalize(),
            'line': {'color': '#3CB4E6' if index == last_index else '#FFD200', 'width': 2, 'dash': 'solid'},
            'hoverinfo': 'x+y+name',
            'hoverlabel': {'bgcolor': '#EEEFF1', 'font': {'color': '#525A63'}}
        }
        for index, column in enumerate(columns)
    ]
    return {
        'data': traces,
        'layout': {
            'xaxis': {
                'title': 'Epochs',
                'showgrid': True,
                'gridcolor': '#EEEFF1',
                'tickangle': 45
            },
            'yaxis': {
                'title': columns[0].capitalize(),
                'showgrid': True,
                'gridcolor': '#EEEFF1'
            },
            'showlegend': True,
            'legend': {
                'x': 1,
                'y': 1,
                'traceorder': 'normal',
                'font': {'size': 10},
                'bgcolor': '#EEEFF1',
                'bordercolor': '#A6ADB5',
                'borderwidth': 1
            },
            'hovermode': 'closest',
            'plot_bgcolor': '#ffffff'
        }
    }


def _render_logs():
    """Join the shared log lines for display, under ``lock``.

    Subprocess lines keep their trailing newline. Stripping it avoids an
    empty line between every entry once the lines are joined with "\\n".
    """
    with lock:
        return "\n".join(line.rstrip("\n") for line in logs)


@app.callback(
    [Output('log-reader', 'children'),
     Output('acc-visualization', 'figure'),
     Output('acc-visualization', 'style'),
     Output('loss-visualization', 'figure'),
     Output('loss-visualization', 'style')],
    [Input('interval-widget', 'n_intervals'),
     Input('process-button', 'n_clicks')],
    [State('selected-model', 'value'),
     State('devcloud-username-input', 'value'),
     State('devcloud-password-input', 'value')]
)
def refresh_metrics(n_intervals, nb_clicks, selected_model, devcloud_username, devcloud_password):
    """
    Updates the log display and training metrics based on user actions and intervals.

    - A click on 'process-button' launches the training script via
      execute_async(), provided a known use case and both credentials are set.
      This is the only place the script is launched.
    - Each interval tick (once a run was launched) reads train_metrics.csv of
      the newest run in ./experiments_outputs and plots the first two
      available metric groups.

    Args:
        n_intervals (int): Interval tick count.
        nb_clicks (int): Click count of the run button.
        selected_model (str): Selected use case (key of ``local_yamls``).
        devcloud_username (str): ST Developer Cloud username.
        devcloud_password (str): ST Developer Cloud password.
    Returns:
        tuple: (log text, accuracy figure, accuracy style, loss figure, loss style).
    Raises:
        PreventUpdate: If the callback was not triggered by a relevant input.
    """
    global new_training
    callback_context = dash.callback_context
    if not callback_context.triggered:
        raise PreventUpdate
    trigger = callback_context.triggered[0]['prop_id'].split('.')[0]

    if trigger == 'process-button' and nb_clicks:
        if selected_model not in local_yamls:
            # Never build a script path from an unvalidated client-supplied value.
            fill_logs("Please select a valid use case before launching training.")
        elif devcloud_username and devcloud_password:
            st_script = f"stm32ai-modelzoo-services/{selected_model}/src/stm32ai_main.py"
            execute_async(st_script, devcloud_username, devcloud_password)
            new_training = True
            fill_logs("Starting application ...")
        else:
            fill_logs("Please enter both ST Developer Cloud username and password:")
        return _render_logs(), {}, {'display': 'none'}, {}, {'display': 'none'}

    if trigger != 'interval-widget':
        raise PreventUpdate

    log_text = _render_logs()
    nothing_to_plot = (log_text, {}, {'display': 'none'}, {}, {'display': 'none'})
    if not new_training:
        return nothing_to_plot

    outputs_folder = "experiments_outputs"
    if not os.path.exists(outputs_folder):
        os.makedirs(outputs_folder)
        return nothing_to_plot

    recent_directory = _latest_dated_run(outputs_folder)
    if recent_directory is None:
        return nothing_to_plot
    train_metrics_file = os.path.join(outputs_folder, recent_directory, 'logs', 'metrics', 'train_metrics.csv')
    logger.debug("Metrics file : %s", train_metrics_file)
    if not os.path.exists(train_metrics_file):
        return nothing_to_plot

    try:
        metrics_dataframe = pd.read_csv(train_metrics_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, OSError):
        # The training process may be in the middle of writing the file.
        return nothing_to_plot
    if metrics_dataframe.empty or 'epoch' not in metrics_dataframe.columns:
        return nothing_to_plot

    figures = [
        _metric_figure(metrics_dataframe, group)
        for group in _REFRESH_METRIC_GROUPS
        if all(column in metrics_dataframe.columns for column in group)
    ]
    if not figures:
        return nothing_to_plot
    second_figure = figures[1] if len(figures) > 1 else {}
    return log_text, figures[0], {'display': 'block'}, second_figure, {'display': 'block'}
def _match_existing_type(raw_value, converted_value, current_value):
    """
    Choose the value to store so it keeps the type already present in the YAML file.

    - Booleans: the Checklist submits ``[True]`` or ``[]``; store ``True``/``False``
      (``[]`` used to be written into the file verbatim).
    - Lists: keep the submitted list even when a single item is selected
      (process_form_configs unwraps single-element lists).
    - Strings: keep the submitted text, so settings such as ``"123"`` are not
      turned into numbers.
    - Otherwise (numbers, None, ...): use the process_form_configs result.
    """
    if isinstance(current_value, bool) and isinstance(raw_value, list):
        return bool(raw_value)
    if isinstance(current_value, list) and isinstance(raw_value, list):
        return list(raw_value)
    if isinstance(current_value, str) and isinstance(raw_value, str):
        return raw_value
    return converted_value


def _apply_form_to_yaml(form_input_ids, form_input_values, selected_model):
    """
    Merge submitted form values into the selected use case's user_config.yaml.

    Returns the status message shown under the form. Errors from reading or
    writing the file propagate to the caller.
    """
    import io

    form_fields_data = {
        input_id['index']: input_value
        for input_id, input_value in zip(form_input_ids, form_input_values)
    }
    yaml_file_path = local_yamls.get(selected_model)
    if not yaml_file_path:
        return f"ERROR: No user config yaml found for '{selected_model}'."
    # ruamel round-trip parser keeps the file's comments and layout.
    yaml_parser = ruamel.yaml.YAML()
    with open(yaml_file_path, 'r') as file:
        current_yaml_data = yaml_parser.load(file)

    new_fields = []
    for key, value in process_form_configs(form_fields_data).items():
        *parents, leaf = key.split('.')
        nested_dict = current_yaml_data
        for parent in parents:
            nested_dict = nested_dict.setdefault(parent, {})
        current_value = nested_dict.get(leaf)
        value = _match_existing_type(form_fields_data[key], value, current_value)
        # A leaf missing from the file (e.g. edited since the form was built) is
        # added instead of raising KeyError.
        if leaf not in nested_dict or current_value != value:
            nested_dict[leaf] = value
            new_fields.append(key)

    # Serialize first: opening with 'w' truncates the file, so a failing dump
    # would otherwise leave an empty config behind.
    buffer = io.StringIO()
    yaml_parser.dump(current_yaml_data, buffer)
    with open(yaml_file_path, 'w') as file:
        file.write(buffer.getvalue())
    return f"User config yaml file has been updated successfully ! Updated fields are: {', '.join(new_fields)}"


@app.callback(
    Output('submission-outcome', 'children'),
    [Input('apply-button', 'n_clicks'),
     Input('process-button', 'n_clicks')],
    [State({'type': 'yaml-setting', 'index': ALL}, 'id'),
     State({'type': 'yaml-setting', 'index': ALL}, 'value'),
     State('selected-model', 'value'),
     State('devcloud-username-input', 'value'),
     State('devcloud-password-input', 'value')]
)
def process_button_actions(submit_clicks, exec_nb_clicks, form_input_ids, form_input_values, selected_model, devcloud_username, devcloud_password):
    """
    Handles the actions triggered by the submit and run buttons.

    - Submit ('apply-button'): writes the form values into the selected use
      case's user_config.yaml and reports which fields changed.
    - Run ('process-button'): only reports status. refresh_metrics listens to
      the same click and is the single place that launches the training
      script. Launching here as well ran the script twice, and also started it
      with missing credentials.

    Args:
        submit_clicks (int): Click count of the submit button.
        exec_nb_clicks (int): Click count of the run button.
        form_input_ids (list): Pattern-matching ids of the form inputs.
        form_input_values (list): Values of the form inputs, in the same order.
        selected_model (str): Selected use case.
        devcloud_username (str): ST Developer Cloud username.
        devcloud_password (str): ST Developer Cloud password.
    Returns:
        str: Status message displayed under the form.
    Raises:
        PreventUpdate: If the callback was not triggered by a relevant click.
    """
    callback_context = dash.callback_context
    if not callback_context.triggered:
        raise PreventUpdate
    triggered_button = callback_context.triggered[0]['prop_id'].split('.')[0]

    if triggered_button == 'apply-button':
        if not submit_clicks:
            raise PreventUpdate
        try:
            return _apply_form_to_yaml(form_input_ids, form_input_values, selected_model)
        except Exception as e:
            return f"ERROR: UPDATING USER CONFIG YAML file: {e}"

    if triggered_button == 'process-button':
        if not exec_nb_clicks:
            raise PreventUpdate
        if devcloud_username and devcloud_password:
            return "Application is running ..."
        return "Please enter both ST Developer Cloud username and password."

    raise PreventUpdate
@app.callback(
    Output('download-action', 'style'),
    [Input('interval-widget', 'n_intervals')],
    [State('selected-model', 'value')]
)
def toggle_download_button(n_intervals, selected_model):
    """
    Show the "Download outputs" button once an experiment run directory exists.

    A run directory is a sub-directory of ./experiments_outputs whose name
    starts with '20'.

    Args:
        n_intervals (int): Interval tick count (only used to re-trigger the check).
        selected_model (str): Selected use case (unused).
    Returns:
        dict: ``{'display': 'block'}`` if a run directory exists, otherwise ``{'display': 'none'}``.
    """
    out_directory = os.path.join(os.getcwd(), "experiments_outputs")
    has_runs = False
    if os.path.exists(out_directory):
        has_runs = any(
            name.startswith('20') and os.path.isdir(os.path.join(out_directory, name))
            for name in os.listdir(out_directory)
        )
    return {'display': 'block' if has_runs else 'none'}
@app.callback(
    Output('download-resource', 'data'),
    [Input('download-action', 'n_clicks')],
    [State('selected-model', 'value')]
)
def generate_download_link(n_clicks, selected_model):
    """
    Send a ZIP archive of the most recent experiment run to the browser.

    The newest sub-directory of ./experiments_outputs whose name parses as
    '%Y_%m_%d_%H_%M_%S' is archived into ``<run>/<run>.zip`` and returned
    through the dcc.Download component.

    Args:
        n_clicks (int): Click count of the "Download outputs" button.
        selected_model (str): Selected use case (unused).
    Returns:
        The payload produced by ``dcc.send_file`` for the ZIP archive.
    Raises:
        PreventUpdate: If the button was never clicked, or no run directory or archive exists.
    """
    if n_clicks is None:
        raise PreventUpdate
    output_directory = os.path.join(os.getcwd(), "./experiments_outputs")
    if not os.path.exists(output_directory):
        raise PreventUpdate

    dated_runs = []
    for name in os.listdir(output_directory):
        if not name.startswith("20") or not os.path.isdir(os.path.join(output_directory, name)):
            continue
        try:
            dated_runs.append((datetime.strptime(name, "%Y_%m_%d_%H_%M_%S"), name))
        except ValueError:
            # A '20...' directory that is not a timestamp used to crash the callback.
            continue
    if not dated_runs:
        raise PreventUpdate

    recent_directory = max(dated_runs)[1]
    recent_directory_path = os.path.join(output_directory, recent_directory)
    zip_file_path = os.path.join(recent_directory_path, f"{recent_directory}.zip")
    # Always rebuild the archive. One created while training was still writing
    # files was previously served unchanged on every later click.
    create_archive(zip_file_path, recent_directory_path)
    if os.path.exists(zip_file_path):
        return dcc.send_file(zip_file_path)
    raise PreventUpdate
@server.route('/download/<path:subpath>')
def download_file(subpath):
    """
    Route to download a file from the experiments_outputs directory.

    Parameters:
    - subpath (str): Path of the file relative to './experiments_outputs'.
    Returns:
    - Response: The file sent as an attachment when it is a regular file inside
      experiments_outputs.
    - tuple: ("File not found", 404) otherwise, including for directories and for
      paths that resolve outside experiments_outputs (e.g. '../..' segments,
      absolute paths, or symlinks pointing elsewhere).
    """
    base_directory = os.path.realpath(os.path.join(os.getcwd(), 'experiments_outputs'))
    # subpath comes straight from the URL: resolve it, then confirm it stays
    # below base_directory before serving anything (path-traversal guard).
    file_path = os.path.realpath(os.path.join(base_directory, subpath))
    try:
        inside_base = os.path.commonpath([base_directory, file_path]) == base_directory
    except ValueError:
        # e.g. paths on different drives (Windows) have no common path.
        inside_base = False
    if inside_base and os.path.isfile(file_path):
        return send_file(file_path, as_attachment=True)
    return "File not found", 404
if __name__ == '__main__':
    # Development server on port 7860 with Dash dev tools (UI + hot reload).
    # NOTE(review): newer Dash releases replace `run_server` with `app.run` —
    # confirm against the pinned Dash version.
    app.run_server(
        host='',
        port=7860,
        dev_tools_ui=True,
        dev_tools_hot_reload=True,
        threaded=True,
    )