Spaces:
Sleeping
Sleeping
Change package entrypoint to parent CLI script
Browse files- marcai/cli.py +42 -0
- marcai/find_matches.py +7 -2
- marcai/predict.py +6 -3
- marcai/process.py +3 -3
- marcai/train.py +7 -5
- setup.cfg +2 -5
marcai/cli.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from . import train, predict, process, find_matches
|
3 |
+
|
4 |
+
|
5 |
+
def main():
|
6 |
+
parser = argparse.ArgumentParser(
|
7 |
+
description="Command-line interface for marcai package"
|
8 |
+
)
|
9 |
+
subparsers = parser.add_subparsers(required=True)
|
10 |
+
|
11 |
+
train_parser = subparsers.add_parser(
|
12 |
+
"train", parents=[train.args_parser()], help="Train a model", add_help=False
|
13 |
+
)
|
14 |
+
predict_parser = subparsers.add_parser(
|
15 |
+
"predict",
|
16 |
+
parents=[predict.args_parser()],
|
17 |
+
help="Make predictions using a trained model",
|
18 |
+
add_help=False,
|
19 |
+
)
|
20 |
+
process_parser = subparsers.add_parser(
|
21 |
+
"process", parents=[process.args_parser()], help="Process data", add_help=False
|
22 |
+
)
|
23 |
+
find_matches_parser = subparsers.add_parser(
|
24 |
+
"find_matches",
|
25 |
+
parents=[find_matches.args_parser()],
|
26 |
+
help="Find matches in data",
|
27 |
+
add_help=False,
|
28 |
+
)
|
29 |
+
|
30 |
+
train_parser.set_defaults(func=train.main)
|
31 |
+
predict_parser.set_defaults(func=predict.main)
|
32 |
+
process_parser.set_defaults(func=process.main)
|
33 |
+
find_matches_parser.set_defaults(func=find_matches.main)
|
34 |
+
|
35 |
+
args = parser.parse_args()
|
36 |
+
|
37 |
+
args.func(args)
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == "__main__":
|
42 |
+
main()
|
marcai/find_matches.py
CHANGED
@@ -10,7 +10,7 @@ from marcai.utils import load_config
|
|
10 |
from marcai.utils.parsing import load_records, record_dict
|
11 |
|
12 |
|
13 |
-
def
|
14 |
parser = argparse.ArgumentParser()
|
15 |
parser.add_argument("-i", "--inputs", nargs="+", help="MARC files", required=True)
|
16 |
parser.add_argument(
|
@@ -32,7 +32,12 @@ def main():
|
|
32 |
parser.add_argument("-o", "--output", help="Output file", required=True)
|
33 |
parser.add_argument("-t", "--threshold", help="Threshold for matching", type=float)
|
34 |
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
config_path = f"{args.model_dir}/config.yaml"
|
38 |
model_onnx = f"{args.model_dir}/model.onnx"
|
|
|
10 |
from marcai.utils.parsing import load_records, record_dict
|
11 |
|
12 |
|
13 |
+
def args_parser():
|
14 |
parser = argparse.ArgumentParser()
|
15 |
parser.add_argument("-i", "--inputs", nargs="+", help="MARC files", required=True)
|
16 |
parser.add_argument(
|
|
|
32 |
parser.add_argument("-o", "--output", help="Output file", required=True)
|
33 |
parser.add_argument("-t", "--threshold", help="Threshold for matching", type=float)
|
34 |
|
35 |
+
return parser
|
36 |
+
|
37 |
+
|
38 |
+
def main():
|
39 |
+
|
40 |
+
args = args_parser().parse_args()
|
41 |
|
42 |
config_path = f"{args.model_dir}/config.yaml"
|
43 |
model_onnx = f"{args.model_dir}/model.onnx"
|
marcai/predict.py
CHANGED
@@ -23,8 +23,7 @@ def predict_onnx(model_onnx_path, data):
|
|
23 |
|
24 |
return ort_outs
|
25 |
|
26 |
-
|
27 |
-
def main():
|
28 |
parser = argparse.ArgumentParser()
|
29 |
parser.add_argument(
|
30 |
"-i", "--input", help="Path to preprocessed data file", required=True
|
@@ -42,8 +41,12 @@ def main():
|
|
42 |
default=1024,
|
43 |
type=int,
|
44 |
)
|
|
|
|
|
|
|
|
|
45 |
|
46 |
-
args =
|
47 |
|
48 |
config_path = f"{args.model_dir}/config.yaml"
|
49 |
model_onnx = f"{args.model_dir}/model.onnx"
|
|
|
23 |
|
24 |
return ort_outs
|
25 |
|
26 |
+
def args_parser():
|
|
|
27 |
parser = argparse.ArgumentParser()
|
28 |
parser.add_argument(
|
29 |
"-i", "--input", help="Path to preprocessed data file", required=True
|
|
|
41 |
default=1024,
|
42 |
type=int,
|
43 |
)
|
44 |
+
return parser
|
45 |
+
|
46 |
+
|
47 |
+
def main():
|
48 |
|
49 |
+
args = args_parser().parse_args()
|
50 |
|
51 |
config_path = f"{args.model_dir}/config.yaml"
|
52 |
model_onnx = f"{args.model_dir}/model.onnx"
|
marcai/process.py
CHANGED
@@ -190,7 +190,7 @@ def process(df0, df1):
|
|
190 |
return result_df
|
191 |
|
192 |
|
193 |
-
def
|
194 |
parser = argparse.ArgumentParser(
|
195 |
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
196 |
)
|
@@ -217,13 +217,13 @@ def parse_args():
|
|
217 |
default=1,
|
218 |
)
|
219 |
|
220 |
-
return parser
|
221 |
|
222 |
|
223 |
def main():
|
224 |
|
225 |
start = time.time()
|
226 |
-
args = parse_args()
|
227 |
|
228 |
# Load records
|
229 |
print("Loading records...")
|
|
|
190 |
return result_df
|
191 |
|
192 |
|
193 |
+
def args_parser():
|
194 |
parser = argparse.ArgumentParser(
|
195 |
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
196 |
)
|
|
|
217 |
default=1,
|
218 |
)
|
219 |
|
220 |
+
return parser
|
221 |
|
222 |
|
223 |
def main():
|
224 |
|
225 |
start = time.time()
|
226 |
+
args = args_parser().parse_args()
|
227 |
|
228 |
# Load records
|
229 |
print("Loading records...")
|
marcai/train.py
CHANGED
@@ -88,12 +88,14 @@ def train(name=None):
|
|
88 |
archive.add(save_dir, arcname=os.path.basename(save_dir))
|
89 |
|
90 |
|
91 |
-
def
|
92 |
parser = argparse.ArgumentParser()
|
93 |
-
parser.add_argument(
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
97 |
|
98 |
train(args.run_name)
|
99 |
|
|
|
88 |
archive.add(save_dir, arcname=os.path.basename(save_dir))
|
89 |
|
90 |
|
91 |
+
def args_parser():
|
92 |
parser = argparse.ArgumentParser()
|
93 |
+
parser.add_argument("-n", "--run-name", help="Name for training run", required=True)
|
94 |
+
return parser
|
95 |
+
|
96 |
+
def main():
|
97 |
+
|
98 |
+
args = args_parser().parse_args()
|
99 |
|
100 |
train(args.run_name)
|
101 |
|
setup.cfg
CHANGED
@@ -7,8 +7,5 @@ packages = find:
|
|
7 |
|
8 |
[options.entry_points]
|
9 |
console_scripts =
|
10 |
-
|
11 |
-
|
12 |
-
train = marcai:train.main
|
13 |
-
find_matches = marcai:find_matches.main
|
14 |
-
|
|
|
7 |
|
8 |
[options.entry_points]
|
9 |
console_scripts =
|
10 |
+
marc-ai = marcai:cli.main
|
11 |
+
|
|
|
|
|
|