Spaces:
Sleeping
Sleeping
import argparse | |
from process import multiprocess_pairs | |
from predict import predict_onnx | |
from tqdm import tqdm | |
import pandas as pd | |
from marcai.utils.parsing import load_records, record_dict | |
from marcai.utils import load_config | |
import csv | |
def main():
    """Score candidate record pairs with an ONNX model and write matches to CSV.

    Command-line arguments:
        -i/--inputs        One or more MARC files to load (required).
        -p/--pair-indices  CSV file containing indices of comparisons (required).
        -C/--chunksize     Chunk size passed to ``multiprocess_pairs``.
        -P/--processes     Number of processes passed to ``multiprocess_pairs``.
        -m/--model-dir     Directory holding ``config.yaml`` and ``model.onnx``
                           (required).
        -o/--output        Path of the CSV file to write (required).
        -t/--threshold     Minimum prediction value a pair must reach to be
                           written. When omitted, every scored pair is written.

    The output file is created by the first chunk that has rows to write;
    later chunks are appended without repeating the header. If no chunk
    yields any rows, no output file is written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--inputs", nargs="+", help="MARC files", required=True)
    parser.add_argument(
        "-p",
        "--pair-indices",
        help="File containing indices of comparisons",
        required=True,
    )
    parser.add_argument("-C", "--chunksize", help="Chunk size", type=int, default=50000)
    parser.add_argument(
        "-P", "--processes", help="Number of processes", type=int, default=1
    )
    parser.add_argument(
        "-m",
        "--model-dir",
        help="Directory containing model ONNX and YAML files",
        required=True,
    )
    parser.add_argument("-o", "--output", help="Output file", required=True)
    parser.add_argument("-t", "--threshold", help="Threshold for matching", type=float)
    args = parser.parse_args()

    config_path = f"{args.model_dir}/config.yaml"
    model_onnx = f"{args.model_dir}/model.onnx"
    config = load_config(config_path)
    # The feature column list does not change between chunks; look it up once.
    features = config["model"]["features"]

    # Load records
    print("Loading records...")
    records = []
    for path in args.inputs:
        records.extend(record_dict(r) for r in load_records(path))
    records_df = pd.DataFrame(records)
    print(f"Loaded {len(records)} records.")

    print("Processing and comparing records...")
    written = False
    # newline="" is what the csv module asks for on files handed to csv.reader.
    with open(args.pair_indices, "r", newline="") as indices_file:
        reader = csv.reader(indices_file)

        # Process records
        for df in tqdm(
            multiprocess_pairs(records_df, reader, args.chunksize, args.processes)
        ):
            prediction = predict_onnx(model_onnx, df[features])
            df.loc[:, "prediction"] = prediction.squeeze()

            # --threshold is optional; comparing against None is not a valid
            # filter, so only filter when a threshold was actually supplied.
            if args.threshold is not None:
                df = df[df["prediction"] >= args.threshold]

            if df.empty:
                continue

            # First write creates the file with a header; later writes append.
            df.to_csv(
                args.output,
                index=False,
                mode="a" if written else "w",
                header=not written,
            )
            written = True
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()