Spaces:
Build error
Build error
File size: 3,865 Bytes
343fa36 c6c0d26 343fa36 c6c0d26 343fa36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
"""
Utils for the demo app.
"""
import os
import re
import subprocess
from demo import constants, state
from lczerolens import Lens, LczeroModel
from lczerolens.model import lczero as lczero_utils
def _list_models(directory, suffix, model_type, pattern):
    """
    Build one ``[filename, model_type, n_blocks, n_filters]`` row for every
    entry of ``directory`` whose name ends with ``suffix``.
    """
    rows = []
    # os.listdir returns entries in arbitrary order; sort for a stable listing.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(suffix):
            continue
        match = pattern.search(filename)
        if match is None:
            # -1 marks a size that could not be read from the file name.
            n_filters = -1
            n_blocks = -1
        else:
            n_filters = int(match.group("n_filters"))
            n_blocks = int(match.group("n_blocks"))
        rows.append([filename, model_type, n_blocks, n_filters])
    return rows


def get_models_info(onnx=True, leela=True):
    """
    List the model files found in the model directories.

    The network size is parsed from the first ``<n_filters>x<n_blocks>``
    pattern found in each file name; both values are ``-1`` when the name
    contains no such pattern.

    Args:
        onnx: If true, include the ``.onnx`` files of
            ``constants.MODEL_DIRECTORY`` (tagged ``"ONNX"``).
        leela: If true, include the ``.pb.gz`` files of
            ``constants.LEELA_MODEL_DIRECTORY`` (tagged ``"LEELA"``).

    Returns:
        A list of ``[filename, model_type, n_blocks, n_filters]`` rows, ONNX
        rows first, each group sorted by file name.
    """
    pattern = re.compile(r"(?P<n_filters>\d+)x(?P<n_blocks>\d+)")
    model_df = []
    if onnx:
        model_df.extend(
            _list_models(constants.MODEL_DIRECTORY, ".onnx", "ONNX", pattern)
        )
    if leela:
        model_df.extend(
            _list_models(
                constants.LEELA_MODEL_DIRECTORY, ".pb.gz", "LEELA", pattern
            )
        )
    return model_df
def save_model(tmp_file_path):
    """
    Move a temporary model file into the Leela model directory.

    The ``file`` utility is run on ``tmp_file_path``. Its description must
    report the type ``gzip`` and contain a ``was "<name>"`` fragment, which
    supplies the name under which the file is stored:
    ``{constants.LEELA_MODEL_DIRECTORY}/<name>.gz``. The moved file is then
    checked with ``lczero_utils.describenet``; if that raises
    ``RuntimeError`` the moved file is deleted again.

    Args:
        tmp_file_path: Path of the temporary file to inspect and move.

    Raises:
        RuntimeError: If ``file`` exits with a non-zero status, its output
            cannot be parsed, the file is not gzip data, the reported name
            is unusable, or ``describenet`` rejects the model.
    """
    # run() drains and closes both pipes; Popen.wait() with PIPE streams can
    # deadlock once a pipe buffer fills up and leaves the pipes open.
    completed = subprocess.run(
        ["file", tmp_file_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if completed.returncode != 0:
        raise RuntimeError(
            f"`file` exited with status {completed.returncode}."
        )

    # Output looks like "<path>: <description>"; keep what follows the path.
    output = completed.stdout.decode("utf-8")
    _, found, file_desc = output.partition(tmp_file_path)
    if not found:
        raise RuntimeError("Could not parse the output of `file`.")
    file_desc = file_desc.strip()

    rename_match = re.search(r"was\s\"(?P<name>.+)\"", file_desc)
    type_match = re.search(r"\:\s(?P<type>[a-zA-Z]+)", file_desc)
    if rename_match is None or type_match is None:
        raise RuntimeError("Could not read the model name and type.")
    model_type = type_match.group("type")
    if model_type != "gzip":
        raise RuntimeError(f"Expected a gzip file, got {model_type!r}.")

    # The name comes from the inspected file itself, so strip any directory
    # part to keep the destination inside the model directory.
    model_name = os.path.basename(rename_match.group("name"))
    if not model_name:
        raise RuntimeError("The model name reported by `file` is empty.")

    target_path = f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz"
    os.rename(tmp_file_path, target_path)
    try:
        lczero_utils.describenet(target_path)
    except RuntimeError as err:
        # Do not keep a file that describenet could not process.
        os.remove(target_path)
        raise RuntimeError("The file is not a valid Leela model.") from err
def get_wrapper_from_state(model_name):
    """
    Return the model wrapper for ``model_name``, loading it on first use.

    Wrappers are cached in ``state.wrappers`` under the model name. On a
    cache miss the model is loaded with ``LczeroModel.from_path`` from
    ``constants.MODEL_DIRECTORY`` and stored before being returned.
    """
    if model_name not in state.wrappers:
        loaded = LczeroModel.from_path(f"{constants.MODEL_DIRECTORY}/{model_name}")
        state.wrappers[model_name] = loaded
        return loaded
    return state.wrappers[model_name]
def get_wrapper_lens_from_state(model_name, lens_type, lens_name="lens", **kwargs):
    """
    Return the model wrapper and a lens, creating and caching them as needed.

    The wrapper comes from ``get_wrapper_from_state``. Lenses are cached in
    ``state.lenses[lens_type]`` under ``lens_name``. A cached lens is
    returned as-is: ``kwargs`` are ignored and compatibility with the
    wrapper is not re-checked.

    Args:
        model_name: Name of the model file, relative to
            ``constants.MODEL_DIRECTORY``.
        lens_type: Lens type name passed to ``Lens.from_name``; also the key
            into ``state.lenses``.
        lens_name: Cache key of the lens within its type.
        **kwargs: Forwarded to ``Lens.from_name`` when a new lens is built.

    Returns:
        The ``(wrapper, lens)`` pair.

    Raises:
        ValueError: If a newly built lens is not compatible with the wrapper.
    """
    # Reuse the shared wrapper cache instead of duplicating its logic here.
    wrapper = get_wrapper_from_state(model_name)

    cached_lenses = state.lenses[lens_type]
    if lens_name in cached_lenses:
        return wrapper, cached_lenses[lens_name]

    lens = Lens.from_name(lens_type, **kwargs)
    if not lens.is_compatible(wrapper):
        raise ValueError(f"Lens of type {lens_type} not compatible with model.")
    # Cache only after the compatibility check has passed.
    cached_lenses[lens_name] = lens
    return wrapper, lens
|