|
""" |
|
Beam Search |
|
=============== |
|
|
|
""" |
|
import numpy as np |
|
|
|
from textattack.goal_function_results import GoalFunctionResultStatus |
|
from textattack.search_methods import SearchMethod |
|
|
|
|
|
class BeamSearch(SearchMethod):
    """A search method that maintains a beam of the ``beam_width`` highest
    scoring AttackedTexts, greedily updating the beam with the highest scoring
    transformations of the current beam.

    Each iteration expands every text in the beam with
    ``get_transformations``, scores all candidates with ``get_goal_results``,
    and keeps the ``beam_width`` highest-scoring candidates as the next beam.
    The search stops when the best candidate of an iteration has goal status
    ``GoalFunctionResultStatus.SUCCEEDED``, when the beam yields no
    transformations, when no results are returned, or when
    ``get_goal_results`` reports that the search is over.

    Args:
        beam_width (int): the number of candidates to retain at each step.
    """

    def __init__(self, beam_width=8):
        self.beam_width = beam_width

    def perform_search(self, initial_result):
        """Run beam search starting from ``initial_result``.

        Args:
            initial_result: the goal-function result for the unperturbed
                text; its ``attacked_text`` seeds the beam and is passed as
                ``original_text`` to every transformation call.

        Returns:
            The highest-scoring result from the most recently scored set of
            candidates, or ``initial_result`` if no candidate was ever scored.
            Note this is the best of the *latest* expansion, not necessarily
            the best result seen across all iterations.
        """
        original_text = initial_result.attacked_text
        beam = [original_text]
        best_result = initial_result

        while best_result.goal_status != GoalFunctionResultStatus.SUCCEEDED:
            # Expand every text currently in the beam into its candidates.
            potential_next_beam = []
            for text in beam:
                potential_next_beam.extend(
                    self.get_transformations(text, original_text=original_text)
                )

            if not potential_next_beam:
                # Nothing left to try from the current beam.
                return best_result

            results, search_over = self.get_goal_results(potential_next_beam)
            if not results:
                # No candidate was scored (presumably because the query
                # budget was already exhausted -- verify against the goal
                # function). argmax of an empty array would raise ValueError,
                # so fall back to the best result found so far.
                return best_result

            scores = np.array([r.score for r in results])
            best_result = results[scores.argmax()]
            if search_over:
                return best_result

            # Indices of the ``beam_width`` highest scores, in descending
            # order. Indices are positions in ``results``; this assumes
            # ``results[i]`` corresponds to ``potential_next_beam[i]``.
            best_indices = (-scores).argsort()[: self.beam_width]
            beam = [potential_next_beam[i] for i in best_indices]

        return best_result

    @property
    def is_black_box(self):
        """Return True: this search only uses ``get_transformations`` and
        ``get_goal_results`` scores, with no access to model internals."""
        return True

    def extra_repr_keys(self):
        """Return the attribute names to include in this search method's
        string representation (presumably consumed by the base class)."""
        return ["beam_width"]
|
|