nielklug's picture
init
6ed21b9
import os
# Index URL that get_downloader() passes to nltk's Downloader as
# ``server_index_url``.
BENEPAR_SERVER_INDEX = "https://kitaev.com/benepar/index.xml"
# Module-level cache for the Downloader instance; remains None until
# get_downloader() builds it on first call.
_downloader = None
def get_downloader():
    """Return the module-wide nltk Downloader, creating it on first use.

    The instance is built with ``BENEPAR_SERVER_INDEX`` as its
    ``server_index_url`` and stored in the module-level ``_downloader`` so
    later calls reuse the same object.
    """
    global _downloader
    if _downloader is not None:
        return _downloader

    # nltk is imported only when a downloader actually has to be built.
    from nltk.downloader import Downloader

    _downloader = Downloader(server_index_url=BENEPAR_SERVER_INDEX)
    return _downloader
def download(*args, **kwargs):
    """Forward all arguments to the shared downloader's ``download`` method.

    Returns whatever ``get_downloader().download(...)`` returns.
    """
    downloader = get_downloader()
    return downloader.download(*args, **kwargs)
def locate_model(name):
    """Resolve a model name or filesystem path to a path on disk.

    Args:
        name: Either an existing filesystem path, or a bare model name
            (containing neither "/" nor ".") to look up through nltk's
            data search under ``models/<name>``.

    Returns:
        ``name`` itself when it exists on disk; otherwise the ``path``
        attribute of the location returned by ``nltk.data.find``.

    Raises:
        LookupError: If ``name`` is not an existing path and either does not
            look like a bare model name or cannot be found by nltk.
    """
    if os.path.exists(name):
        return name

    # Anything containing a path separator or a dot is treated as a
    # (missing) file path rather than a model name, so nltk is not consulted.
    if "/" in name or "." in name:
        raise LookupError(f"Can't find {name}")

    import nltk.data

    try:
        nltk_loc = nltk.data.find(f"models/{name}")
    except LookupError as e:
        # Point the user at benepar's downloader instead of nltk's.
        arg = e.args[0].replace("nltk.download", "benepar.download")
        # ``from None`` suppresses the implicit exception context; otherwise
        # the traceback would still print nltk's original message with the
        # "nltk.download" hint that is being replaced here.
        raise LookupError(arg) from None
    return nltk_loc.path
def load_trained_model(model_name_or_path):
    """Load a trained chart parser from a model name or path.

    The argument is first resolved with ``locate_model``; the resulting path
    is then passed to ``ChartParser.from_trained`` and the parser it returns
    is handed back to the caller.
    """
    resolved_path = locate_model(model_name_or_path)

    # Imported at call time, after the model has been located.
    from ..parse_chart import ChartParser

    return ChartParser.from_trained(resolved_path)