lczerolens-demo / app /utils.py
Xmaster6y's picture
no more factories
c6c0d26
"""
Utils for the demo app.
"""
import os
import re
import subprocess
from demo import constants, state
from lczerolens import Lens, LczeroModel
from lczerolens.model import lczero as lczero_utils
def get_models_info(onnx=True, leela=True):
"""
Get the names of the models in the model directory.
"""
model_df = []
exp = r"(?P<n_filters>\d+)x(?P<n_blocks>\d+)"
if onnx:
for filename in os.listdir(constants.MODEL_DIRECTORY):
if filename.endswith(".onnx"):
match = re.search(exp, filename)
if match is None:
n_filters = -1
n_blocks = -1
else:
n_filters = int(match.group("n_filters"))
n_blocks = int(match.group("n_blocks"))
model_df.append(
[
filename,
"ONNX",
n_blocks,
n_filters,
]
)
if leela:
for filename in os.listdir(constants.LEELA_MODEL_DIRECTORY):
if filename.endswith(".pb.gz"):
match = re.search(exp, filename)
if match is None:
n_filters = -1
n_blocks = -1
else:
n_filters = int(match.group("n_filters"))
n_blocks = int(match.group("n_blocks"))
model_df.append(
[
filename,
"LEELA",
n_blocks,
n_filters,
]
)
return model_df
def save_model(tmp_file_path):
"""
Save the model to the model directory.
"""
popen = subprocess.Popen(
["file", tmp_file_path],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
popen.wait()
if popen.returncode != 0:
raise RuntimeError
file_desc = popen.stdout.read().decode("utf-8").split(tmp_file_path)[1].strip()
rename_match = re.search(r"was\s\"(?P<name>.+)\"", file_desc)
type_match = re.search(r"\:\s(?P<type>[a-zA-Z]+)", file_desc)
if rename_match is None or type_match is None:
raise RuntimeError
model_name = rename_match.group("name")
model_type = type_match.group("type")
if model_type != "gzip":
raise RuntimeError
os.rename(
tmp_file_path,
f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz",
)
try:
lczero_utils.describenet(
f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz",
)
except RuntimeError:
os.remove(f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz")
raise RuntimeError
def get_wrapper_from_state(model_name):
"""
Get the model wrapper from the state.
"""
if model_name in state.wrappers:
return state.wrappers[model_name]
else:
wrapper = LczeroModel.from_path(f"{constants.MODEL_DIRECTORY}/{model_name}")
state.wrappers[model_name] = wrapper
return wrapper
def get_wrapper_lens_from_state(model_name, lens_type, lens_name="lens", **kwargs):
"""
Get the model wrapper and lens from the state.
"""
if model_name in state.wrappers:
wrapper = state.wrappers[model_name]
else:
wrapper = LczeroModel.from_path(f"{constants.MODEL_DIRECTORY}/{model_name}")
state.wrappers[model_name] = wrapper
if lens_name in state.lenses[lens_type]:
lens = state.lenses[lens_type][lens_name]
else:
lens = Lens.from_name(lens_type, **kwargs)
if not lens.is_compatible(wrapper):
raise ValueError(f"Lens of type {lens_type} not compatible with model.")
state.lenses[lens_type][lens_name] = lens
return wrapper, lens