File size: 2,314 Bytes
fbf7e95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
from process import multiprocess_pairs
from predict import predict_onnx
from tqdm import tqdm
import pandas as pd

from marcai.utils.parsing import load_records, record_dict
from marcai.utils import load_config

import csv

def main():
    """Score candidate MARC record pairs with an ONNX model and write matches.

    Loads MARC records from the input files, streams pair indices from a CSV
    file in chunks, computes model features per pair, runs the ONNX model,
    and appends rows whose prediction meets the threshold to the output CSV.

    Side effects: reads the model config/ONNX files, prints progress, and
    writes/overwrites the output CSV (header written once, on the first
    non-empty chunk).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--inputs", nargs="+", help="MARC files", required=True)
    parser.add_argument(
        "-p",
        "--pair-indices",
        help="File containing indices of comparisons",
        required=True,
    )
    parser.add_argument("-C", "--chunksize", help="Chunk size", type=int, default=50000)
    parser.add_argument(
        "-P", "--processes", help="Number of processes", type=int, default=1
    )
    parser.add_argument(
        "-m",
        "--model-dir",
        help="Directory containing model ONNX and YAML files",
        required=True,
    )
    parser.add_argument("-o", "--output", help="Output file", required=True)
    parser.add_argument("-t", "--threshold", help="Threshold for matching", type=float)

    args = parser.parse_args()

    config_path = f"{args.model_dir}/config.yaml"
    model_onnx = f"{args.model_dir}/model.onnx"

    config = load_config(config_path)

    # Load all records from every input file into one DataFrame.
    print("Loading records...")
    records = []
    for path in args.inputs:
        records.extend(record_dict(r) for r in load_records(path))

    records_df = pd.DataFrame(records)

    print(f"Loaded {len(records)} records.")

    print("Processing and comparing records...")
    written = False
    with open(args.pair_indices, "r") as indices_file:
        reader = csv.reader(indices_file)
        # Process pair chunks in parallel; each yielded df is one chunk.
        for df in tqdm(
            multiprocess_pairs(records_df, reader, args.chunksize, args.processes)
        ):
            input_df = df[config["model"]["features"]]
            prediction = predict_onnx(model_onnx, input_df)
            df.loc[:, "prediction"] = prediction.squeeze()

            # Bug fix: --threshold has no default, so args.threshold may be
            # None, and `Series >= None` raises TypeError in pandas. Only
            # filter when a threshold was actually supplied; otherwise keep
            # every scored pair.
            if args.threshold is not None:
                df = df[df["prediction"] >= args.threshold]

            if df.empty:
                continue

            # First non-empty chunk creates the file with a header; later
            # chunks append without repeating it.
            df.to_csv(
                args.output,
                index=False,
                mode="a" if written else "w",
                header=not written,
            )
            written = True


if __name__ == "__main__":
    main()