File size: 9,113 Bytes
4a1df2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
""".. _goal_function:

GoalFunction Class
===========================================================
"""


from abc import ABC, abstractmethod

import lru
import numpy as np
import torch

from textattack.goal_function_results.goal_function_result import (
    GoalFunctionResultStatus,
)
from textattack.shared import validators
from textattack.shared.utils import ReprMixin


class GoalFunction(ReprMixin, ABC):
    """Evaluates how well a perturbed attacked_text object is achieving a
    specified goal.

    Args:
        model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
            The victim model to attack.
        maximizable(:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the goal function is maximizable, as opposed to a boolean result of success or failure.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to keep model outputs in an LRU cache keyed by the attacked text.
        query_budget (:obj:`float`, `optional`, defaults to :obj:`float("inf")`):
            The maximum number of model queries allowed.
        model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
            The number of inputs passed to the model in a single call.
        model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**20`):
            The maximum number of items to keep in the model results cache at once.
    """

    def __init__(
        self,
        model_wrapper,
        maximizable=False,
        use_cache=True,
        query_budget=float("inf"),
        model_batch_size=32,
        model_cache_size=2**20,
    ):
        validators.validate_model_goal_function_compatibility(
            self.__class__, model_wrapper.model.__class__
        )
        self.model = model_wrapper
        self.maximizable = maximizable
        self.use_cache = use_cache
        self.query_budget = query_budget
        self.batch_size = model_batch_size
        if self.use_cache:
            self._call_model_cache = lru.LRU(model_cache_size)
        else:
            self._call_model_cache = None

    def clear_cache(self):
        """Empties the model results cache (no-op when caching is disabled)."""
        if self.use_cache:
            self._call_model_cache.clear()

    def init_attack_example(self, attacked_text, ground_truth_output):
        """Called before attacking ``attacked_text`` to 'reset' the goal
        function and set properties for this example.

        Resets the query counter to zero and returns the
        ``(result, search_over)`` pair produced by ``get_result`` for the
        original text with ``check_skip=True``.
        """
        self.initial_attacked_text = attacked_text
        self.ground_truth_output = ground_truth_output
        self.num_queries = 0
        return self.get_result(attacked_text, check_skip=True)

    def get_output(self, attacked_text):
        """Returns output for display based on the result of calling the
        model.

        This goes straight to ``_call_model`` and therefore does not
        increment ``num_queries``.
        """
        return self._get_displayed_output(self._call_model([attacked_text])[0])

    def get_result(self, attacked_text, **kwargs):
        """A helper method that queries ``self.get_results`` with a single
        ``AttackedText`` object.

        Returns ``(result, search_over)``; ``result`` is ``None`` when
        ``get_results`` returned no results (e.g. the query budget is
        exhausted).
        """
        results, search_over = self.get_results([attacked_text], **kwargs)
        result = results[0] if results else None
        return result, search_over

    def get_results(self, attacked_text_list, check_skip=False):
        """For each attacked_text object in attacked_text_list, returns a
        result consisting of whether or not the goal has been achieved, the
        output for display purposes, and a score.

        Additionally returns whether the search is over due to the query
        budget.
        """
        results = []
        budget = self.query_budget
        if budget < float("inf"):
            # The budget may be supplied as a float (e.g. ``100.0``), but a
            # slice bound must be an integer. Clamp at zero because a
            # negative bound would slice from the end of the list instead of
            # returning nothing.
            budget = int(budget)
            queries_left = max(0, budget - self.num_queries)
            attacked_text_list = attacked_text_list[:queries_left]
        self.num_queries += len(attacked_text_list)
        model_outputs = self._call_model(attacked_text_list)
        result_cls = self._goal_function_result_type()
        for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
            displayed_output = self._get_displayed_output(raw_output)
            goal_status = self._get_goal_status(
                raw_output, attacked_text, check_skip=check_skip
            )
            goal_function_score = self._get_score(raw_output, attacked_text)
            results.append(
                result_cls(
                    attacked_text,
                    raw_output,
                    displayed_output,
                    goal_status,
                    goal_function_score,
                    self.num_queries,
                    self.ground_truth_output,
                )
            )
        # ``>=`` rather than ``==`` so the search is still reported as over
        # if the counter ever ends up past the budget.
        return results, self.num_queries >= budget

    def _get_goal_status(self, model_output, attacked_text, check_skip=False):
        """Maps a model output to a ``GoalFunctionResultStatus``.

        Order of precedence: SKIPPED (only when ``check_skip`` is set and
        ``_should_skip`` agrees), MAXIMIZING (for maximizable goal
        functions), SUCCEEDED (goal complete), otherwise SEARCHING.
        """
        should_skip = check_skip and self._should_skip(model_output, attacked_text)
        if should_skip:
            return GoalFunctionResultStatus.SKIPPED
        if self.maximizable:
            return GoalFunctionResultStatus.MAXIMIZING
        if self._is_goal_complete(model_output, attacked_text):
            return GoalFunctionResultStatus.SUCCEEDED
        return GoalFunctionResultStatus.SEARCHING

    @abstractmethod
    def _is_goal_complete(self, model_output, attacked_text):
        """Returns whether ``model_output`` satisfies the goal."""
        raise NotImplementedError()

    def _should_skip(self, model_output, attacked_text):
        """Returns whether this example should be skipped.

        By default an example is skipped when the goal is already complete
        for it.
        """
        return self._is_goal_complete(model_output, attacked_text)

    @abstractmethod
    def _get_score(self, model_output, attacked_text):
        """Returns the goal function score for ``model_output``."""
        raise NotImplementedError()

    def _get_displayed_output(self, raw_output):
        """Returns the output to display; the raw output by default."""
        return raw_output

    @abstractmethod
    def _goal_function_result_type(self):
        """Returns the class of this goal function's results."""
        raise NotImplementedError()

    @abstractmethod
    def _process_model_outputs(self, inputs, outputs):
        """Processes and validates a list of model outputs.

        This is a task-dependent operation. For example, classification
        outputs need to make sure they have a softmax applied.
        """
        raise NotImplementedError()

    def _call_model_uncached(self, attacked_text_list):
        """Queries model and returns outputs for a list of AttackedText
        objects.

        Inputs are sent to the model in chunks of ``self.batch_size`` and
        the combined outputs are passed through
        ``_process_model_outputs``.
        """
        if len(attacked_text_list) == 0:
            return []

        inputs = [at.tokenizer_input for at in attacked_text_list]
        outputs = []
        for start in range(0, len(inputs), self.batch_size):
            batch = inputs[start : start + self.batch_size]
            batch_preds = self.model(batch)

            # Some seq-to-seq models will return a single string as a prediction
            # for a single-string list. Wrap these in a list.
            if isinstance(batch_preds, str):
                batch_preds = [batch_preds]

            # Get PyTorch tensors off of other devices.
            if isinstance(batch_preds, torch.Tensor):
                batch_preds = batch_preds.cpu()

            if isinstance(batch_preds, list):
                outputs.extend(batch_preds)
            elif isinstance(batch_preds, np.ndarray):
                outputs.append(torch.tensor(batch_preds))
            else:
                outputs.append(batch_preds)

        # Per-batch tensors are joined along the first dimension so that
        # ``len(outputs)`` counts examples rather than batches.
        if isinstance(outputs[0], torch.Tensor):
            outputs = torch.cat(outputs, dim=0)

        assert len(inputs) == len(
            outputs
        ), f"Got {len(outputs)} outputs for {len(inputs)} inputs"

        return self._process_model_outputs(attacked_text_list, outputs)

    def _call_model(self, attacked_text_list):
        """Gets predictions for a list of ``AttackedText`` objects.

        Gets prediction from cache if possible. If prediction is not in
        the cache, queries model and stores prediction in cache.

        Each distinct uncached text is sent to the model once, even when it
        appears several times in ``attacked_text_list``.
        """
        if not self.use_cache:
            return self._call_model_uncached(attacked_text_list)

        cache = self._call_model_cache
        # Outputs for this call are collected locally so that the returned
        # list never depends on entries surviving in the LRU cache (they
        # could be evicted while the new outputs are being stored).
        resolved = {}
        uncached_list = []
        for text in attacked_text_list:
            if text in resolved:
                continue
            if text in cache:
                output = cache[text]
                # Re-write value in cache. This moves the key to the top of
                # the LRU cache so that it is the last candidate for eviction
                # when the outputs for `uncached_list` are stored.
                cache[text] = output
                resolved[text] = output
            else:
                uncached_list.append(text)
                # Placeholder; replaced once the model has been queried.
                resolved[text] = None

        outputs = self._call_model_uncached(uncached_list)
        for text, output in zip(uncached_list, outputs):
            cache[text] = output
            resolved[text] = output
        return [resolved[text] for text in attacked_text_list]

    def extra_repr_keys(self):
        """Returns the attribute names to include in the repr: the query
        budget when it is finite and ``maximizable`` when it is set."""
        attrs = []
        if self.query_budget < float("inf"):
            attrs.append("query_budget")
        if self.maximizable:
            attrs.append("maximizable")
        return attrs

    def __getstate__(self):
        """Returns a copy of the instance state in which the cache object is
        replaced by its size, so cached contents are not part of the state."""
        state = self.__dict__.copy()
        if self.use_cache:
            state["_call_model_cache"] = self._call_model_cache.get_size()
        return state

    def __setstate__(self, state):
        """Restores the instance state and rebuilds an empty cache from the
        size stored by ``__getstate__``."""
        self.__dict__ = state
        if self.use_cache:
            self._call_model_cache = lru.LRU(state["_call_model_cache"])