Runtime error
Runtime error
File size: 3,011 Bytes
d825710 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
"""A little helper scripts to generate the requirements.txt and models.json with
the latest supported model versions based on the compatibility.json."""
from spacy.about import __compatibility__ as COMPAT_URL
from spacy.util import get_lang_class, is_compatible_version
from pathlib import Path
import requests
import typer
import srsly
URL_TEMPLATE = "{name}-{version}/{name}-{version}.tar.gz#egg={name}=={version}"
# Seconds to wait on any HTTP request before giving up, so a stalled
# connection cannot hang the script forever.
_REQUEST_TIMEOUT = 30
# Known model size suffixes, in the order they should appear in the dropdown.
_SIZE_ORDER = ("sm", "md", "lg", "trf")


def _url_exists(url: str) -> bool:
    """Return False if the server answers 404 for *url*, True otherwise.

    Only the first byte is requested and the response body is streamed and
    never read, so large model archives are not downloaded just to check them.
    """
    # "bytes=0-0" is valid HTTP range syntax (first byte only).
    with requests.get(
        url,
        headers={"Range": "bytes=0-0"},
        stream=True,
        timeout=_REQUEST_TIMEOUT,
    ) as r:
        return r.status_code != 404


def _model_sort_key(item):
    """Sort key for (model_name, description) pairs: language name, then size.

    The language name is the first word of the description. A size suffix not
    in _SIZE_ORDER sorts after the known sizes of the same language instead of
    raising ValueError.
    """
    model_name, desc = item
    lang_name = desc.split(" ", 1)[0]
    suffix = model_name.rsplit("_", 1)[-1]
    if suffix in _SIZE_ORDER:
        size_rank = _SIZE_ORDER.index(suffix)
    else:
        size_rank = len(_SIZE_ORDER)
    return (lang_name, size_rank, model_name)


def main(
    # fmt: off
    spacy_version: str = typer.Argument(">=3.0.0,<3.1.0", help="The spaCy version range"),
    spacy_streamlit_version: str = typer.Argument(">=1.0.0rc1,<1.1.0", help="The version range of spacy-streamlit"),
    req_path: Path = typer.Option(Path(__file__).parent / "requirements.txt", "--requirements-path", "-rp", help="Path to requirements.txt"),
    desc_path: Path = typer.Option(Path(__file__).parent / "models.json", "--models-json-path", "-mp", help="Path to models.json with model details for dropdown"),
    package: str = typer.Option("spacy", "--package", "-p", help="The parent package (spacy, spacy-nightly, etc.)"),
    exclude: str = typer.Option("en_vectors_web_lg", "--exclude", "-e", help="Comma-separated model names to exclude"),
    # fmt: on
):
    """Write requirements.txt and models.json for the compatible spaCy models.

    Fetches spaCy's compatibility table from COMPAT_URL and selects the model
    list for a spaCy version matching *spacy_version*. For every model not in
    *exclude*, it checks that the archive for ``model_versions[0]`` exists.
    It then writes two files:

    * *req_path*: the *package* and spacy-streamlit version pins, followed by
      one archive URL per model (pip requirements format).
    * *desc_path*: a JSON object mapping each model name to a
      "LanguageName (model_name)" label, sorted by language, then model size.

    Raises:
        requests.HTTPError: if the compatibility table cannot be fetched.
        ValueError: if no entry in the compatibility table matches
            *spacy_version*.
    """
    excluded = {name.strip() for name in exclude.split(",")}
    r = requests.get(COMPAT_URL, timeout=_REQUEST_TIMEOUT)
    r.raise_for_status()
    compat = r.json()["spacy"]
    data = None
    # NOTE(review): each compatible entry overwrites `data`, so whichever comes
    # last in the JSON wins. Confirm the table's key order matches the intent
    # of picking the latest compatible version.
    for version_option, version_models in compat.items():
        if is_compatible_version(version_option, spacy_version):
            data = version_models
    if data is None:
        raise ValueError(f"No compatible models found for {spacy_version}")
    reqs = [
        f"# Auto-generated by {Path(__file__).name}",
        f"{package}{spacy_version}",
        f"spacy-streamlit{spacy_streamlit_version}",
    ]
    models = {}
    for model_name, model_versions in data.items():
        if model_name in excluded or not model_versions:
            continue
        # Presumably the compatibility table lists the newest version first.
        url = URL_TEMPLATE.format(name=model_name, version=model_versions[0])
        # Quick existence check so broken links never reach requirements.txt
        # or the dropdown.
        if not _url_exists(url):
            print(f"Invalid package URL (skipping): {url}")
            continue
        reqs.append(url)
        lang = model_name.split("_", 1)[0]
        lang_name = get_lang_class(lang).__name__
        models[model_name] = f"{lang_name} ({model_name})"
    # Sort by human-readable language name, then by model size
    models = dict(sorted(models.items(), key=_model_sort_key))
    with Path(req_path).open("w", encoding="utf8") as f:
        f.write("\n".join(reqs) + "\n")
    srsly.write_json(desc_path, models)
    print(f"Generated requirements.txt and models.json for {len(models)} models")
if __name__ == "__main__":