""" Gradio interface for converting models. """ import os import uuid import gradio as gr from demo import constants, utils from lczerolens.model import lczero as lczero_utils def list_models(): """ List the models in the model directory. """ models_info = utils.get_models_info() return sorted([[model_info[0]] for model_info in models_info]) def on_select_model_df( evt: gr.SelectData, ): """ When a model is selected, update the statement. """ return evt.value def convert_model( model_name: str, ): """ Convert the model. """ if model_name == "": gr.Warning( "Please select a model.", ) return list_models(), "" if model_name.endswith(".onnx"): gr.Warning( "ONNX conversion not implemented.", ) return list_models(), "" try: lczero_utils.convert_to_onnx( f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}", f"{constants.MODEL_DIRECTORY}/{model_name[:-6]}.onnx", ) except RuntimeError: gr.Warning( f"Could not convert net at `{model_name}`.", ) return list_models(), "Conversion failed" return list_models(), "Conversion successful" def upload_model( model_file: gr.File, ): """ Convert the model. """ if model_file is None: gr.Warning( "File not uploaded.", ) return list_models() try: id = uuid.uuid4() tmp_file_path = f"{constants.LEELA_MODEL_DIRECTORY}/{id}" with open( tmp_file_path, "wb", ) as f: f.write(model_file) utils.save_model(tmp_file_path) except RuntimeError: gr.Warning( "Invalid file type.", ) finally: if os.path.exists(tmp_file_path): os.remove(tmp_file_path) return list_models() def get_model_description( model_name: str, ): """ Get the model description. """ if model_name == "": gr.Warning( "Please select a model.", ) return "" if model_name.endswith(".onnx"): gr.Warning( "ONNX description not implemented.", ) return "" try: description = lczero_utils.describenet( f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}", ) except RuntimeError: raise gr.Error( f"Could not describe net at `{model_name}`.", ) return description def get_model_path( model_name: str, ): """ Get the model path. """ if model_name == "": gr.Warning( "Please select a model.", ) return None if model_name.endswith(".onnx"): return f"{constants.MODEL_DIRECTORY}/{model_name}" else: return f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}" with gr.Blocks() as interface: model_file = gr.File(type="binary") upload_button = gr.Button( value="Upload", ) with gr.Row(): with gr.Column(scale=2): model_df = gr.Dataframe( headers=["Available models"], datatype=["str"], interactive=False, type="array", value=list_models, ) with gr.Column(scale=1): with gr.Row(): model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7) conversion_status = gr.Textbox( label="Conversion status", lines=1, interactive=False, ) convert_button = gr.Button( value="Convert", ) describe_button = gr.Button( value="Describe model", ) model_description = gr.Textbox( label="Model description", lines=1, interactive=False, ) download_button = gr.Button( value="Get download link", ) download_file = gr.File( type="filepath", label="Download link", interactive=False, ) model_df.select( on_select_model_df, None, model_name, ) upload_button.click( upload_model, model_file, model_df, ) convert_button.click( convert_model, model_name, [model_df, conversion_status], ) describe_button.click( get_model_description, model_name, model_description, ) download_button.click( get_model_path, model_name, download_file, )