File size: 1,505 Bytes
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""Script to create MLE model for retriever component of CRB-CRS.

For ReDial, use the following command:
python -m script.crb_crs.create_mle \
    --corpus_file data/redial/GT_corpus_tokens.txt \
    --output_file data/models/crb_crs/mle_model.pkl
"""

import argparse
import logging
import os

from src.model.crb_crs.retriever.mle_model import NGramMLE


def parse_args() -> argparse.Namespace:
    """Builds the command line interface and parses the arguments.

    Returns:
        Namespace with the string attributes `corpus_file` and `output_file`
        (both required) and the integer attribute `n` (defaults to 2).
    """
    cli = argparse.ArgumentParser(
        description="Create MLE model for retriever component of CRB-CRS."
    )

    # The two mandatory path options only differ in flag name and help text.
    required_paths = (
        ("--corpus_file", "Path to the corpus file."),
        ("--output_file", "Path to save the created MLE model."),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=str, required=True, help=description)

    cli.add_argument(
        "--n",
        type=int,
        default=2,
        help="Maximum n-gram order. Defaults to 2.",
    )

    namespace = cli.parse_args()
    return namespace


def main(args: argparse.Namespace) -> None:
    """Creates MLE model for retriever component of CRB-CRS.

    Instantiates an NGramMLE from the n-gram order and corpus path, creates
    its n-grams, then saves the model to the requested output path. The
    parent directory of the output path is created if it does not exist.

    Args:
        args: Command line arguments. The attributes `n`, `corpus_file` and
            `output_file` are read.
    """
    model = NGramMLE(args.n, args.corpus_file)
    model.create_ngrams()

    # A bare filename (e.g. "mle_model.pkl") has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError. Only create the directory
    # when the output path actually has a directory component.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    model.save(args.output_file)
    logging.info(f"MLE model saved at {args.output_file}.")


if __name__ == "__main__":
    # Configure logging first so the INFO message emitted by main() is shown.
    logging.basicConfig(level=logging.INFO)
    cli_args = parse_args()
    main(cli_args)