""" Gradio interface for converting models. """ import os import uuid import re import subprocess import gradio as gr from demo import constants, utils from lczerolens import backends def get_models_info(onnx=True, leela=True): """ Get the names of the models in the model directory. """ model_df = [] exp = r"(?P\d+)x(?P\d+)" if onnx: for filename in os.listdir(constants.ONNX_MODEL_DIRECTORY): if filename.endswith(".onnx"): match = re.search(exp, filename) if match is None: n_filters = -1 n_blocks = -1 else: n_filters = int(match.group("n_filters")) n_blocks = int(match.group("n_blocks")) model_df.append( [ filename, "ONNX", n_blocks, n_filters, ] ) if leela: for filename in os.listdir(constants.LEELA_MODEL_DIRECTORY): if filename.endswith(".pb.gz"): match = re.search(exp, filename) if match is None: n_filters = -1 n_blocks = -1 else: n_filters = int(match.group("n_filters")) n_blocks = int(match.group("n_blocks")) model_df.append( [ filename, "LEELA", n_blocks, n_filters, ] ) return model_df def save_model(tmp_file_path): """ Save the model to the model directory. """ popen = subprocess.Popen( ["file", tmp_file_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) popen.wait() if popen.returncode != 0: raise RuntimeError file_desc = popen.stdout.read().decode("utf-8").split(tmp_file_path)[1].strip() rename_match = re.search(r"was\s\"(?P.+)\"", file_desc) type_match = re.search(r"\:\s(?P[a-zA-Z]+)", file_desc) if rename_match is None or type_match is None: raise RuntimeError model_name = rename_match.group("name") model_type = type_match.group("type") if model_type != "gzip": raise RuntimeError os.rename( tmp_file_path, f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz", ) try: backends.describenet( f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz", ) except RuntimeError: os.remove(f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz") raise RuntimeError def list_models(): """ List the models in the model directory. """ models_info = get_models_info() return sorted([[model_info[0]] for model_info in models_info]) def on_select_model_df( evt: gr.SelectData, ): """ When a model is selected, update the statement. """ return evt.value def convert_model( model_name: str, ): """ Convert the model. """ if model_name == "": gr.Warning( "Please select a model.", ) return list_models(), "" if model_name.endswith(".onnx"): gr.Warning( "ONNX conversion not implemented.", ) return list_models(), "" try: backends.convert_to_onnx( f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}", f"{constants.ONNX_MODEL_DIRECTORY}/{model_name[:-6]}.onnx", ) except RuntimeError: gr.Warning( f"Could not convert net at `{model_name}`.", ) return list_models(), "Conversion failed" return list_models(), "Conversion successful" def upload_model( model_file: gr.File, ): """ Convert the model. """ if model_file is None: gr.Warning( "File not uploaded.", ) return list_models() try: id = uuid.uuid4() tmp_file_path = f"{constants.LEELA_MODEL_DIRECTORY}/{id}" with open( tmp_file_path, "wb", ) as f: f.write(model_file) save_model(tmp_file_path) except RuntimeError: gr.Warning( "Invalid file type.", ) finally: if os.path.exists(tmp_file_path): os.remove(tmp_file_path) return list_models() def get_model_description( model_name: str, ): """ Get the model description. """ if model_name == "": gr.Warning( "Please select a model.", ) return "" if model_name.endswith(".onnx"): gr.Warning( "ONNX description not implemented.", ) return "" try: description = backends.describenet( f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}", ) except RuntimeError: raise gr.Error( f"Could not describe net at `{model_name}`.", ) return description def get_model_path( model_name: str, ): """ Get the model path. """ if model_name == "": gr.Warning( "Please select a model.", ) return None if model_name.endswith(".onnx"): return f"{constants.ONNX_MODEL_DIRECTORY}/{model_name}" else: return f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}" with gr.Blocks() as interface: model_file = gr.File(type="binary") upload_button = gr.Button( value="Upload", ) with gr.Row(): with gr.Column(scale=2): model_df = gr.Dataframe( headers=["Available models"], datatype=["str"], interactive=False, type="array", value=list_models, ) with gr.Column(scale=1): with gr.Row(): model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7) conversion_status = gr.Textbox( label="Conversion status", lines=1, interactive=False, ) convert_button = gr.Button( value="Convert", ) describe_button = gr.Button( value="Describe model", ) model_description = gr.Textbox( label="Model description", lines=1, interactive=False, ) download_button = gr.Button( value="Get download link", ) download_file = gr.File( type="filepath", label="Download link", interactive=False, ) model_df.select( on_select_model_df, None, model_name, ) upload_button.click( upload_model, model_file, model_df, ) convert_button.click( convert_model, model_name, [model_df, conversion_status], ) describe_button.click( get_model_description, model_name, model_description, ) download_button.click( get_model_path, model_name, download_file, )