|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
from argparse import ArgumentParser |
|
from dataclasses import dataclass, field |
|
|
|
from fairseq.dataclass import FairseqDataclass |
|
from fairseq.dataclass.utils import gen_parser_from_dataclass |
|
|
|
|
|
@dataclass
class A(FairseqDataclass):
    """Leaf config used by the tests: a ``data`` string and a ``num_layers`` int."""

    # Exercised in the tests below as a bare (positional) command-line value.
    data: str = field(
        metadata={"help": "the data input"},
        default="test",
    )
    num_layers: int = field(
        metadata={"help": "more layers is better?"},
        default=200,
    )
|
|
|
|
|
@dataclass
class B(FairseqDataclass):
    """Config nesting an ``A`` under ``bar`` alongside its own ``foo`` option."""

    # Use default_factory rather than ``default=A()``: dataclass instances are
    # unhashable, so Python 3.11+ rejects them as mutable defaults, and on older
    # versions every B would share one A instance.
    bar: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
|
|
|
|
|
@dataclass
class D(FairseqDataclass):
    """Config nesting an ``A`` under ``arch`` alongside its own ``foo`` option.

    When nested as ``C.encoder`` with prefixing enabled, the ``A`` options
    are exercised as ``--encoder-arch-*`` in the tests below.
    """

    # Use default_factory rather than ``default=A()``: dataclass instances are
    # unhashable, so Python 3.11+ rejects them as mutable defaults, and on older
    # versions every D would share one A instance.
    arch: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
|
|
|
|
|
@dataclass
class C(FairseqDataclass):
    """Root config with nested ``encoder`` (a ``D``) and ``decoder`` (an ``A``).

    Used to test prefixed argument names for multiply-nested dataclasses.
    """

    data: str = field(default="test", metadata={"help": "root level data input"})
    # default_factory gives each C its own nested configs; ``default=D()`` /
    # ``default=A()`` is rejected on Python 3.11+ (unhashable dataclass
    # instances count as mutable defaults) and shared state on older versions.
    encoder: D = field(default_factory=D)
    decoder: A = field(default_factory=A)
    lr: int = field(default=0, metadata={"help": "learning rate"})
|
|
|
|
|
class TestDataclassUtils(unittest.TestCase):
    """Checks how ``gen_parser_from_dataclass`` maps dataclass fields to argparse."""

    def _parse(self, config, argv, *prefix):
        """Build a parser from ``config`` (delete_default=True) and parse ``argv``.

        ``prefix`` is forwarded as the optional fourth positional argument
        of ``gen_parser_from_dataclass`` (the ``with_prefix`` value).
        """
        parser = ArgumentParser()
        gen_parser_from_dataclass(parser, config, True, *prefix)
        return parser.parse_args(argv)

    def test_argparse_convert_basic(self):
        parsed = self._parse(A(), ["--num-layers", "10", "the/data/path"])
        self.assertEqual(parsed.num_layers, 10)
        self.assertEqual(parsed.data, "the/data/path")

    def test_argparse_recursive(self):
        parsed = self._parse(
            B(), ["--num-layers", "10", "--foo", "10", "the/data/path"]
        )
        self.assertEqual(parsed.num_layers, 10)
        self.assertEqual(parsed.foo, 10)
        self.assertEqual(parsed.data, "the/data/path")

    def test_argparse_recursive_prefixing(self):
        self.maxDiff = None
        argv = [
            "--encoder-arch-data", "ENCODER_ARCH_DATA",
            "--encoder-arch-num-layers", "10",
            "--encoder-foo", "10",
            "--decoder-data", "DECODER_DATA",
            "--decoder-num-layers", "10",
            "--lr", "10",
            "the/data/path",
        ]
        # An empty-string prefix turns on prefixing of nested field names.
        parsed = self._parse(C(), argv, "")

        # Checked in insertion order; the first mismatch fails the test.
        expected = {
            "encoder_arch_data": "ENCODER_ARCH_DATA",
            "encoder_arch_num_layers": 10,
            "encoder_foo": 10,
            "decoder_data": "DECODER_DATA",
            "decoder_num_layers": 10,
            "lr": 10,
            "data": "the/data/path",
        }
        for attr_name, want in expected.items():
            self.assertEqual(getattr(parsed, attr_name), want)
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script.
    unittest.main(module="__main__")
|
|