"""
AttackArgs Class
================
"""

from dataclasses import dataclass, field
import json
import os
import sys
import time
from typing import Dict, Optional

import textattack
from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file

from .attack import Attack
from .dataset_args import DatasetArgs
from .model_args import ModelArgs

# Registry of attack recipes selectable via ``--attack-recipe``.
# Keys are the short command-line names; values are dotted paths to the recipe
# classes. The paths are turned into objects by evaluating
# ``<path>.build(model_wrapper, ...)`` in
# ``_CommandLineAttackArgs._create_attack_from_args``.
ATTACK_RECIPE_NAMES = {
    "alzantot": "textattack.attack_recipes.GeneticAlgorithmAlzantot2018",
    "bae": "textattack.attack_recipes.BAEGarg2019",
    "bert-attack": "textattack.attack_recipes.BERTAttackLi2020",
    "faster-alzantot": "textattack.attack_recipes.FasterGeneticAlgorithmJia2019",
    "deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
    "hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
    "input-reduction": "textattack.attack_recipes.InputReductionFeng2018",
    "kuleshov": "textattack.attack_recipes.Kuleshov2017",
    "morpheus": "textattack.attack_recipes.MorpheusTan2020",
    "seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox",
    "textbugger": "textattack.attack_recipes.TextBuggerLi2018",
    "textfooler": "textattack.attack_recipes.TextFoolerJin2019",
    "pwws": "textattack.attack_recipes.PWWSRen2019",
    "iga": "textattack.attack_recipes.IGAWang2019",
    "pruthi": "textattack.attack_recipes.Pruthi2019",
    "pso": "textattack.attack_recipes.PSOZang2020",
    "checklist": "textattack.attack_recipes.CheckList2020",
    "clare": "textattack.attack_recipes.CLARE2020",
    "a2t": "textattack.attack_recipes.A2TYoo2021",
}


# Registry of transformations selectable via ``--transformation`` that are
# constructed WITHOUT the victim model: the class path is evaluated as
# ``<path>(<params>)`` in
# ``_CommandLineAttackArgs._create_transformation_from_args``.
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
    "random-synonym-insertion": "textattack.transformations.RandomSynonymInsertion",
    "word-deletion": "textattack.transformations.WordDeletion",
    "word-swap-embedding": "textattack.transformations.WordSwapEmbedding",
    "word-swap-homoglyph": "textattack.transformations.WordSwapHomoglyphSwap",
    "word-swap-inflections": "textattack.transformations.WordSwapInflections",
    "word-swap-neighboring-char-swap": "textattack.transformations.WordSwapNeighboringCharacterSwap",
    "word-swap-random-char-deletion": "textattack.transformations.WordSwapRandomCharacterDeletion",
    "word-swap-random-char-insertion": "textattack.transformations.WordSwapRandomCharacterInsertion",
    "word-swap-random-char-substitution": "textattack.transformations.WordSwapRandomCharacterSubstitution",
    "word-swap-wordnet": "textattack.transformations.WordSwapWordNet",
    "word-swap-masked-lm": "textattack.transformations.WordSwapMaskedLM",
    "word-swap-hownet": "textattack.transformations.WordSwapHowNet",
    "word-swap-qwerty": "textattack.transformations.WordSwapQWERTY",
}


# Registry of transformations selectable via ``--transformation`` that receive
# the victim model: the class path is evaluated as
# ``<path>(model_wrapper.model, <params>)`` in
# ``_CommandLineAttackArgs._create_transformation_from_args``.
WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
    "word-swap-gradient": "textattack.transformations.WordSwapGradientBased"
}


# Registry of constraints selectable via ``--constraints``.
# Keys are the short command-line names; values are dotted class paths that are
# evaluated as ``<path>(<params>)`` in
# ``_CommandLineAttackArgs._create_constraints_from_args``.
CONSTRAINT_CLASS_NAMES = {
    #
    # Semantics constraints
    #
    "embedding": "textattack.constraints.semantics.WordEmbeddingDistance",
    "bert": "textattack.constraints.semantics.sentence_encoders.BERT",
    "infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent",
    "thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector",
    "use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder",
    "muse": "textattack.constraints.semantics.sentence_encoders.MultilingualUniversalSentenceEncoder",
    "bert-score": "textattack.constraints.semantics.BERTScore",
    #
    # Grammaticality constraints
    #
    "lang-tool": "textattack.constraints.grammaticality.LanguageTool",
    "part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech",
    "goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
    "gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
    "learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
    "cola": "textattack.constraints.grammaticality.COLA",
    #
    # Overlap constraints
    #
    "bleu": "textattack.constraints.overlap.BLEU",
    "chrf": "textattack.constraints.overlap.chrF",
    "edit-distance": "textattack.constraints.overlap.LevenshteinEditDistance",
    "meteor": "textattack.constraints.overlap.METEOR",
    "max-words-perturbed": "textattack.constraints.overlap.MaxWordsPerturbed",
    #
    # Pre-transformation constraints
    #
    "repeat": "textattack.constraints.pre_transformation.RepeatModification",
    "stopword": "textattack.constraints.pre_transformation.StopwordModification",
    "max-word-index": "textattack.constraints.pre_transformation.MaxWordIndexModification",
}


# Registry of search methods selectable via ``--search-method``.
# Values are dotted class paths evaluated as ``<path>(<params>)`` in
# ``_CommandLineAttackArgs._create_attack_from_args``.
SEARCH_METHOD_CLASS_NAMES = {
    "beam-search": "textattack.search_methods.BeamSearch",
    "greedy": "textattack.search_methods.GreedySearch",
    "ga-word": "textattack.search_methods.GeneticAlgorithm",
    "greedy-word-wir": "textattack.search_methods.GreedyWordSwapWIR",
    "pso": "textattack.search_methods.ParticleSwarmOptimization",
}


# Registry of goal functions selectable via ``--goal-function``.
# Values are dotted class paths evaluated as ``<path>(model_wrapper, <params>)``
# in ``_CommandLineAttackArgs._create_goal_function_from_args``.
GOAL_FUNCTION_CLASS_NAMES = {
    #
    # Classification goal functions
    #
    "targeted-classification": "textattack.goal_functions.classification.TargetedClassification",
    "untargeted-classification": "textattack.goal_functions.classification.UntargetedClassification",
    "input-reduction": "textattack.goal_functions.classification.InputReduction",
    #
    # Text goal functions
    #
    "minimize-bleu": "textattack.goal_functions.text.MinimizeBleu",
    "non-overlapping-output": "textattack.goal_functions.text.NonOverlappingOutput",
    "text-to-text": "textattack.goal_functions.text.TextToTextGoalFunction",
}


@dataclass
class AttackArgs:
    """Attack arguments to be passed to :class:`~textattack.Attacker`.

    Args:
        num_examples (:obj:`int`, `optional`, defaults to :obj:`10`):
            The number of examples to attack. :obj:`-1` for entire dataset.
        num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
            The number of successful adversarial examples we want. This is different from :obj:`num_examples`
            as :obj:`num_examples` only cares about attacking `N` samples while :obj:`num_successful_examples` aims to keep attacking
            until we have `N` successful cases.

            .. note::
                If set, this argument overrides `num_examples` argument.
        num_examples_offset (:obj:`int`, `optional`, defaults to :obj:`0`):
            The offset index to start at in the dataset.
        attack_n (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to run attack until total of `N` examples have been attacked (and not skipped).
        shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, we randomly shuffle the dataset before attacking. However, this avoids actually shuffling
            the dataset internally and opts for shuffling the list of indices of examples we want to attack. This means
            :obj:`shuffle` can now be used with checkpoint saving.
        query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
            The maximum number of model queries allowed per example attacked.
            If not set, we use the query budget set in the :class:`~textattack.goal_functions.GoalFunction` object (which by default is :obj:`float("inf")`).

            .. note::
                Setting this overwrites the query budget set in :class:`~textattack.goal_functions.GoalFunction` object.
        checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
            If set, checkpoint will be saved after attacking every `N` examples. If :obj:`None` is passed, no checkpoints will be saved.
        checkpoint_dir (:obj:`str`, `optional`, defaults to :obj:`"checkpoints"`):
            The directory to save checkpoint files.
        random_seed (:obj:`int`, `optional`, defaults to :obj:`765`):
            Random seed for reproducibility.
        parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, run attack using multiple CPUs/GPUs.
        num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
            Number of worker processes to run per device in parallel mode (i.e. :obj:`parallel=True`). For example, if you are using GPUs and :obj:`num_workers_per_device=2`,
            then 2 processes will be running in each GPU.
        log_to_txt (:obj:`str`, `optional`, defaults to :obj:`None`):
            If set, save attack logs as a `.txt` file to the directory specified by this argument.
            If the last part of the provided path ends with `.txt` extension, it is assumed to be the desired path of the log file.
        log_to_csv (:obj:`str`, `optional`, defaults to :obj:`None`):
            If set, save attack logs as a CSV file to the directory specified by this argument.
            If the last part of the provided path ends with `.csv` extension, it is assumed to be the desired path of the log file.
        log_summary_to_json (:obj:`str`, `optional`, defaults to :obj:`None`):
            If set, save the attack summary as a JSON file to the directory specified by this argument.
            If the last part of the provided path ends with `.json` extension, it is assumed to be the desired path of the summary file.
        csv_coloring_style (:obj:`str`, `optional`, defaults to :obj:`"file"`):
            Method for choosing how to mark perturbed parts of the text. Options are :obj:`"file"`, :obj:`"plain"`, and :obj:`"html"`.
            :obj:`"file"` wraps perturbed parts with double brackets :obj:`[[ <text> ]]` while :obj:`"plain"` does not mark the text in any way.
        log_to_visdom (:obj:`dict`, `optional`, defaults to :obj:`None`):
            If set, Visdom logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.VisdomLogger`.
            Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
            three keys and their corresponding values: :obj:`"env", "port", "hostname"`.
        log_to_wandb (:obj:`dict`, `optional`, defaults to :obj:`None`):
            If set, WandB logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.WeightsAndBiasesLogger`.
            Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
            key and its corresponding value: :obj:`"project"`.
        disable_stdout (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Disable displaying individual attack results to stdout.
        silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`.
        enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Enable calculation and display of optional advance post-hoc metrics like perplexity, grammar errors, etc.
        metrics (:obj:`dict`, `optional`, defaults to :obj:`None`):
            Forwarded as the sole constructor argument of :class:`~textattack.loggers.AttackLogManager`
            when loggers are created by :meth:`create_loggers_from_args`.
    """

    num_examples: Optional[int] = 10
    num_successful_examples: Optional[int] = None
    num_examples_offset: int = 0
    attack_n: bool = False
    shuffle: bool = False
    query_budget: Optional[int] = None
    checkpoint_interval: Optional[int] = None
    checkpoint_dir: str = "checkpoints"
    random_seed: int = 765  # equivalent to sum((ord(c) for c in "TEXTATTACK"))
    parallel: bool = False
    num_workers_per_device: int = 1
    log_to_txt: Optional[str] = None
    log_to_csv: Optional[str] = None
    log_summary_to_json: Optional[str] = None
    csv_coloring_style: str = "file"
    log_to_visdom: Optional[dict] = None
    log_to_wandb: Optional[dict] = None
    disable_stdout: bool = False
    silent: bool = False
    enable_advance_metrics: bool = False
    metrics: Optional[Dict] = None

    def __post_init__(self):
        """Validate field values after dataclass initialisation.

        Raises:
            AssertionError: if any numeric field is outside its allowed range.
        """
        # A target number of successes takes precedence over a fixed number of
        # examples, so the latter is cleared.
        if self.num_successful_examples:
            self.num_examples = None
        if self.num_examples:
            assert (
                self.num_examples >= 0 or self.num_examples == -1
            ), "`num_examples` must be greater than or equal to 0 or equal to -1."
        if self.num_successful_examples:
            assert (
                self.num_successful_examples >= 0
            ), "`num_successful_examples` must be greater than or equal to 0."

        if self.query_budget:
            assert self.query_budget > 0, "`query_budget` must be greater than 0."

        if self.checkpoint_interval:
            assert (
                self.checkpoint_interval > 0
            ), "`checkpoint_interval` must be greater than 0."

        assert (
            self.num_workers_per_device > 0
        ), "`num_workers_per_device` must be greater than 0."

    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser.

        Defaults are taken from a default-constructed instance of ``cls`` so
        the command line and the dataclass stay in sync.

        Args:
            parser: an ``argparse``-style parser to extend.

        Returns:
            The same ``parser`` object, with the attack arguments registered.
        """
        default_obj = cls()
        # `--num-examples` and `--num-successful-examples` cannot be combined.
        num_ex_group = parser.add_mutually_exclusive_group(required=False)
        num_ex_group.add_argument(
            "--num-examples",
            "-n",
            type=int,
            default=default_obj.num_examples,
            help="The number of examples to process, -1 for entire dataset.",
        )
        num_ex_group.add_argument(
            "--num-successful-examples",
            type=int,
            default=default_obj.num_successful_examples,
            help="The number of successful adversarial examples we want.",
        )
        parser.add_argument(
            "--num-examples-offset",
            "-o",
            type=int,
            required=False,
            default=default_obj.num_examples_offset,
            help="The offset to start at in the dataset.",
        )
        parser.add_argument(
            "--query-budget",
            "-q",
            type=int,
            default=default_obj.query_budget,
            help="The maximum number of model queries allowed per example attacked. Setting this overwrites the query budget set in `GoalFunction` object.",
        )
        parser.add_argument(
            "--shuffle",
            action="store_true",
            default=default_obj.shuffle,
            help="If `True`, shuffle the samples before we attack the dataset. Default is False.",
        )
        parser.add_argument(
            "--attack-n",
            action="store_true",
            default=default_obj.attack_n,
            help="Whether to run attack until `n` examples have been attacked (not skipped).",
        )
        parser.add_argument(
            "--checkpoint-dir",
            required=False,
            type=str,
            default=default_obj.checkpoint_dir,
            help="The directory to save checkpoint files.",
        )
        parser.add_argument(
            "--checkpoint-interval",
            required=False,
            type=int,
            default=default_obj.checkpoint_interval,
            help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
        )
        parser.add_argument(
            "--random-seed",
            default=default_obj.random_seed,
            type=int,
            help="Random seed for reproducibility.",
        )
        parser.add_argument(
            "--parallel",
            action="store_true",
            default=default_obj.parallel,
            help="Run attack using multiple GPUs.",
        )
        parser.add_argument(
            "--num-workers-per-device",
            default=default_obj.num_workers_per_device,
            type=int,
            help="Number of worker processes to run per device.",
        )
        # For the three file loggers below, passing the flag without a value
        # yields "" (the `const`), which later resolves to a timestamped file
        # in the current directory.
        parser.add_argument(
            "--log-to-txt",
            nargs="?",
            default=default_obj.log_to_txt,
            const="",
            type=str,
            help="Path to which to save attack logs as a text file. Set this argument if you want to save text logs. "
            "If the last part of the path ends with `.txt` extension, the path is assumed to path for output file.",
        )
        parser.add_argument(
            "--log-to-csv",
            nargs="?",
            default=default_obj.log_to_csv,
            const="",
            type=str,
            help="Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs. "
            "If the last part of the path ends with `.csv` extension, the path is assumed to path for output file.",
        )
        parser.add_argument(
            "--log-summary-to-json",
            nargs="?",
            default=default_obj.log_summary_to_json,
            const="",
            type=str,
            help="Path to which to save attack summary as a JSON file. Set this argument if you want to save attack results summary in a JSON. "
            "If the last part of the path ends with `.json` extension, the path is assumed to path for output file.",
        )
        parser.add_argument(
            "--csv-coloring-style",
            default=default_obj.csv_coloring_style,
            type=str,
            help='Method for choosing how to mark perturbed parts of the text in CSV logs. Options are "file" and "plain". '
            '"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".',
        )
        parser.add_argument(
            "--log-to-visdom",
            nargs="?",
            default=None,
            const='{"env": "main", "port": 8097, "hostname": "localhost"}',
            type=json.loads,
            help="Set this argument if you want to log attacks to Visdom. The dictionary should have the following "
            'three keys and their corresponding values: `"env", "port", "hostname"`. '
            'Example for command line use: `--log-to-visdom {"env": "main", "port": 8097, "hostname": "localhost"}`.',
        )
        parser.add_argument(
            "--log-to-wandb",
            nargs="?",
            default=None,
            const='{"project": "textattack"}',
            type=json.loads,
            help="Set this argument if you want to log attacks to WandB. The dictionary should have the following "
            'key and its corresponding value: `"project"`. '
            'Example for command line use: `--log-to-wandb {"project": "textattack"}`.',
        )
        parser.add_argument(
            "--disable-stdout",
            action="store_true",
            default=default_obj.disable_stdout,
            help="Disable logging attack results to stdout",
        )
        parser.add_argument(
            "--silent",
            action="store_true",
            default=default_obj.silent,
            help="Disable all logging",
        )
        parser.add_argument(
            "--enable-advance-metrics",
            action="store_true",
            default=default_obj.enable_advance_metrics,
            help="Enable calculation and display of optional advance post-hoc metrics like perplexity, USE distance, etc.",
        )

        return parser

    @staticmethod
    def _prepare_log_path(path, extension, default_filename):
        """Resolve a log target to a file path and ensure its directory exists.

        Args:
            path (str): user-supplied path; either a file ending with
                ``extension`` (case-insensitive) or a directory.
            extension (str): file extension, including the leading dot.
            default_filename (str): file name used when ``path`` is a directory.

        Returns:
            str: path of the log file to write.
        """
        if path.lower().endswith(extension):
            file_path = path
        else:
            file_path = os.path.join(path, default_filename)

        # An empty dirname means "current directory".
        dir_path = os.path.dirname(file_path) or "."
        # `exist_ok=True` avoids a check-then-create race between processes.
        os.makedirs(dir_path, exist_ok=True)
        return file_path

    @classmethod
    def create_loggers_from_args(cls, args):
        """Creates AttackLogManager from an AttackArgs object.

        Args:
            args (AttackArgs): arguments describing which loggers to enable.

        Returns:
            The configured ``textattack.loggers.AttackLogManager``.

        Raises:
            AssertionError: if ``args`` is not an instance of ``cls``.
        """
        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{cls}`, but got type `{type(args)}`."

        # Create logger
        attack_log_manager = textattack.loggers.AttackLogManager(args.metrics)

        # Get current time for file naming
        timestamp = time.strftime("%Y-%m-%d-%H-%M")

        # if '--log-to-txt' specified with arguments
        if args.log_to_txt is not None:
            txt_file_path = cls._prepare_log_path(
                args.log_to_txt, ".txt", f"{timestamp}-log.txt"
            )
            attack_log_manager.add_output_file(txt_file_path, "file")

        # if '--log-to-csv' specified with arguments
        if args.log_to_csv is not None:
            csv_file_path = cls._prepare_log_path(
                args.log_to_csv, ".csv", f"{timestamp}-log.csv"
            )
            # "plain" is expressed to the CSV logger as "no color method".
            color_method = (
                None if args.csv_coloring_style == "plain" else args.csv_coloring_style
            )
            attack_log_manager.add_output_csv(csv_file_path, color_method)

        # if '--log-summary-to-json' specified with arguments
        if args.log_summary_to_json is not None:
            summary_json_file_path = cls._prepare_log_path(
                args.log_summary_to_json,
                ".json",
                f"{timestamp}-attack_summary_log.json",
            )
            attack_log_manager.add_output_summary_json(summary_json_file_path)

        # Visdom
        if args.log_to_visdom is not None:
            attack_log_manager.enable_visdom(**args.log_to_visdom)

        # Weights & Biases
        if args.log_to_wandb is not None:
            attack_log_manager.enable_wandb(**args.log_to_wandb)

        # Stdout: when output is not a terminal `disable_color()` is called
        # instead of `enable_stdout()`.
        # NOTE(review): presumably `disable_color()` still attaches a stdout
        # logger, just without terminal colors - confirm in AttackLogManager.
        if not args.disable_stdout:
            if sys.stdout.isatty():
                attack_log_manager.enable_stdout()
            else:
                attack_log_manager.disable_color()

        return attack_log_manager


@dataclass
class _CommandLineAttackArgs:
    """Attack args for command line execution. This requires more arguments to
    create ``Attack`` object as specified.

    Component arguments (transformation, constraints, goal function, search
    method, recipe) may carry constructor parameters after ``ARGS_SPLIT_TOKEN``;
    everything after the first occurrence of the token is treated as the
    parameter string.

    Args:
        transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
            Name of transformation to use.
        constraints (:obj:`list[str]`, `optional`, defaults to :obj:`["repeat", "stopword"]`):
            List of names of constraints to use.
        goal_function (:obj:`str`, `optional`, defaults to :obj:`"untargeted-classification"`):
            Name of goal function to use.
        search_method (:obj:`str`, `optional`, defaults to :obj:`"greedy-word-wir"`):
            Name of search method to use.
        attack_recipe (:obj:`str`, `optional`, defaults to :obj:`None`):
            Name of attack recipe to use.

            .. note::
                Setting this overrides any previous selection of transformation, constraints, goal function, and search method.
        attack_from_file (:obj:`str`, `optional`, defaults to :obj:`None`):
            Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specify which variable to import from the file.

            .. note::
                If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method
        interactive (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If `True`, carry attack in interactive mode.
        parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If `True`, attack in parallel.
        model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
            The batch size for making queries to the victim model.
        model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
            The maximum number of items to keep in the model results cache at once.
        constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
            The maximum number of items to keep in the constraints cache at once.
    """

    transformation: str = "word-swap-embedding"
    constraints: list = field(default_factory=lambda: ["repeat", "stopword"])
    goal_function: str = "untargeted-classification"
    search_method: str = "greedy-word-wir"
    attack_recipe: Optional[str] = None
    attack_from_file: Optional[str] = None
    interactive: bool = False
    parallel: bool = False
    model_batch_size: int = 32
    model_cache_size: int = 2**18
    constraint_cache_size: int = 2**18

    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser.

        Args:
            parser: an ``argparse``-style parser to extend.

        Returns:
            The same ``parser`` object, with the attack-building arguments
            registered.
        """
        default_obj = cls()
        # Sorted so the help text is stable from run to run.
        transformation_names = sorted(
            set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES)
            | set(WHITE_BOX_TRANSFORMATION_CLASS_NAMES)
        )
        parser.add_argument(
            "--transformation",
            type=str,
            required=False,
            default=default_obj.transformation,
            help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + ", ".join(transformation_names),
        )
        parser.add_argument(
            "--constraints",
            type=str,
            required=False,
            nargs="*",
            default=default_obj.constraints,
            help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + ", ".join(CONSTRAINT_CLASS_NAMES),
        )
        goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
        parser.add_argument(
            "--goal-function",
            "-g",
            default=default_obj.goal_function,
            help=f"The goal function to use. choices: {goal_function_choices}",
        )
        # A search method, a full recipe and an attack file are alternative
        # ways of specifying the attack, hence mutually exclusive.
        attack_group = parser.add_mutually_exclusive_group(required=False)
        search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
        attack_group.add_argument(
            "--search-method",
            "--search",
            "-s",
            type=str,
            required=False,
            default=default_obj.search_method,
            help=f"The search method to use. choices: {search_choices}",
        )
        # NOTE(review): `choices` restricts the value to the bare recipe
        # names, so a "<recipe><token><params>" value is rejected by the
        # parser even though `_create_attack_from_args` can handle it.
        attack_group.add_argument(
            "--attack-recipe",
            "--recipe",
            "-r",
            type=str,
            required=False,
            default=default_obj.attack_recipe,
            help="full attack recipe (overrides provided goal function, transformation & constraints)",
            choices=ATTACK_RECIPE_NAMES.keys(),
        )
        attack_group.add_argument(
            "--attack-from-file",
            type=str,
            required=False,
            default=default_obj.attack_from_file,
            help="Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specifiy which variable to import from the file.",
        )
        parser.add_argument(
            "--interactive",
            action="store_true",
            default=default_obj.interactive,
            help="Whether to run attacks interactively.",
        )
        parser.add_argument(
            "--model-batch-size",
            type=int,
            default=default_obj.model_batch_size,
            help="The batch size for making calls to the model.",
        )
        parser.add_argument(
            "--model-cache-size",
            type=int,
            default=default_obj.model_cache_size,
            help="The maximum number of items to keep in the model results cache at once.",
        )
        parser.add_argument(
            "--constraint-cache-size",
            type=int,
            default=default_obj.constraint_cache_size,
            help="The maximum number of items to keep in the constraints cache at once.",
        )

        return parser

    @classmethod
    def _create_transformation_from_args(cls, args, model_wrapper):
        """Create `Transformation` based on provided `args` and
        `model_wrapper`.

        White-box transformations receive ``model_wrapper.model`` as their
        first argument; black-box transformations receive only the parameters
        given on the command line.

        Raises:
            ValueError: if the transformation name is not registered.
        """
        # `partition` splits on the first token only, so the parameter string
        # may itself contain the token. `params` is "" when no token is given.
        transformation_name, _, params = args.transformation.partition(
            ARGS_SPLIT_TOKEN
        )

        # SECURITY NOTE: `params` is evaluated with `eval`. It must only ever
        # come from a trusted command line, never from untrusted input.
        if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
            class_path = WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]
            return eval(f"{class_path}(model_wrapper.model, {params})")
        if transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
            class_path = BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]
            return eval(f"{class_path}({params})")
        raise ValueError(f"Error: unsupported transformation {transformation_name}")

    @classmethod
    def _create_goal_function_from_args(cls, args, model_wrapper):
        """Create `GoalFunction` based on provided `args` and
        `model_wrapper`.

        ``args`` must also provide ``query_budget`` (declared on
        ``AttackArgs``, not on this class).

        Raises:
            ValueError: if the goal function name is not registered.
        """
        goal_function_name, _, params = args.goal_function.partition(
            ARGS_SPLIT_TOKEN
        )
        if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
            raise ValueError(f"Error: unsupported goal_function {goal_function_name}")

        # SECURITY NOTE: `params` is evaluated with `eval`; trusted input only.
        goal_function = eval(
            f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model_wrapper, {params})"
        )
        # Only override the goal function's own budget when one was requested.
        if args.query_budget:
            goal_function.query_budget = args.query_budget
        goal_function.model_cache_size = args.model_cache_size
        goal_function.batch_size = args.model_batch_size
        return goal_function

    @classmethod
    def _create_constraints_from_args(cls, args):
        """Create list of `Constraints` based on provided `args`.

        Returns:
            list: one constraint object per entry of ``args.constraints``
            (empty when no constraints are given).

        Raises:
            ValueError: if a constraint name is not registered.
        """
        if not args.constraints:
            return []

        _constraints = []
        for constraint in args.constraints:
            constraint_name, _, params = constraint.partition(ARGS_SPLIT_TOKEN)
            if constraint_name not in CONSTRAINT_CLASS_NAMES:
                raise ValueError(f"Error: unsupported constraint {constraint_name}")
            # SECURITY NOTE: `params` is evaluated with `eval`; trusted input only.
            _constraints.append(
                eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})")
            )

        return _constraints

    @classmethod
    def _create_attack_from_args(cls, args, model_wrapper):
        """Given ``CommandLineArgs`` and ``ModelWrapper``, return specified
        ``Attack`` object.

        Precedence: ``attack_recipe`` first, then ``attack_from_file``, and
        otherwise the attack is assembled from the individual goal function,
        transformation, constraints and search method.

        Raises:
            AssertionError: if ``args`` is not an instance of ``cls``.
            ValueError: if a recipe or search method name is not registered,
                or the requested variable is missing from the attack file.
        """
        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{cls}`, but got type `{type(args)}`."

        if args.attack_recipe:
            recipe_name, _, params = args.attack_recipe.partition(ARGS_SPLIT_TOKEN)
            if recipe_name not in ATTACK_RECIPE_NAMES:
                raise ValueError(f"Error: unsupported recipe {recipe_name}")
            # SECURITY NOTE: `params` is evaluated with `eval`; trusted input only.
            recipe = eval(
                f"{ATTACK_RECIPE_NAMES[recipe_name]}.build(model_wrapper, {params})"
            )
            if args.query_budget:
                recipe.goal_function.query_budget = args.query_budget
            recipe.goal_function.model_cache_size = args.model_cache_size
            # Mirror `_create_goal_function_from_args` so `--model-batch-size`
            # is honoured for recipes as well.
            recipe.goal_function.batch_size = args.model_batch_size
            recipe.constraint_cache_size = args.constraint_cache_size
            return recipe

        if args.attack_from_file:
            attack_file, separator, attack_name = args.attack_from_file.partition(
                ARGS_SPLIT_TOKEN
            )
            if not separator:
                # No variable name given: fall back to the conventional one.
                attack_name = "attack"
            attack_module = load_module_from_file(attack_file)
            if not hasattr(attack_module, attack_name):
                raise ValueError(
                    f"Loaded `{attack_file}` but could not find `{attack_name}`."
                )
            attack_func = getattr(attack_module, attack_name)
            return attack_func(model_wrapper)

        goal_function = cls._create_goal_function_from_args(args, model_wrapper)
        transformation = cls._create_transformation_from_args(args, model_wrapper)
        constraints = cls._create_constraints_from_args(args)

        search_name, _, params = args.search_method.partition(ARGS_SPLIT_TOKEN)
        if search_name not in SEARCH_METHOD_CLASS_NAMES:
            raise ValueError(f"Error: unsupported search {search_name}")
        # SECURITY NOTE: `params` is evaluated with `eval`; trusted input only.
        search_method = eval(f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})")

        return Attack(
            goal_function,
            constraints,
            transformation,
            search_method,
            constraint_cache_size=args.constraint_cache_size,
        )


# This neat trick allows us to reorder the arguments to avoid TypeErrors commonly found when inheriting dataclass.
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@dataclass
class CommandLineAttackArgs(AttackArgs, _CommandLineAttackArgs, DatasetArgs, ModelArgs):
    """Union of all argument groups used by the command line attack entry point."""

    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser.

        Each parent argument group registers its own options, in the order
        model, dataset, attack-building, attack-run.
        """
        argument_groups = (ModelArgs, DatasetArgs, _CommandLineAttackArgs, AttackArgs)
        for group in argument_groups:
            parser = group._add_parser_args(parser)
        return parser