anonymous8/RPD-Demo
initial commit
4943752
raw
history blame
2.04 kB
"""
Beam Search
===============
"""
import numpy as np
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod
class BeamSearch(SearchMethod):
    """An attack that maintains a beam of the ``beam_width`` highest scoring
    AttackedTexts, greedily updating the beam with the highest scoring
    transformations from the current beam.

    At every step each candidate in the beam is expanded with all of its
    transformations, every expansion is scored through
    ``self.get_goal_results``, and the ``beam_width`` best-scoring expansions
    become the next beam. The search stops when the best result of a step has
    status ``SUCCEEDED``, when ``get_goal_results`` reports that the search is
    over, or when no candidate can be transformed further.

    The goal function and transformation are not constructor arguments; this
    class reaches them only through ``self.get_goal_results`` and
    ``self.get_transformations`` (presumably provided/wired up via the
    ``SearchMethod`` base class by the owning attack -- confirm there).

    Args:
        beam_width (int): the number of candidates to retain at each step
    """

    def __init__(self, beam_width=8):
        self.beam_width = beam_width

    def perform_search(self, initial_result):
        """Run beam search starting from ``initial_result``.

        Args:
            initial_result: goal-function result for the starting text; its
                ``attacked_text`` seeds the beam and is passed as
                ``original_text`` to every transformation call.

        Returns:
            The highest-scoring result of the most recently scored step, or
            ``initial_result`` if no step ever produced a scored candidate.
        """
        original_text = initial_result.attacked_text
        beam = [original_text]
        best_result = initial_result
        while best_result.goal_status != GoalFunctionResultStatus.SUCCEEDED:
            # Expand every beam member; transformations are always generated
            # relative to the original (unperturbed) text.
            potential_next_beam = []
            for text in beam:
                potential_next_beam.extend(
                    self.get_transformations(text, original_text=original_text)
                )
            if not potential_next_beam:
                # If we did not find any possible perturbations, give up.
                return best_result

            results, search_over = self.get_goal_results(potential_next_beam)
            if not results:
                # Nothing was scored (presumably the goal function's query
                # budget was already exhausted -- confirm in GoalFunction).
                # ``argmax`` on an empty array would raise ValueError, so
                # return the best result found so far instead.
                return best_result

            scores = np.array([r.score for r in results])
            best_result = results[scores.argmax()]
            if search_over:
                return best_result

            # Refill the beam. This works by sorting the scores in descending
            # order and filling the beam from there. ``results`` may be shorter
            # than ``potential_next_beam``; indices into it are still valid.
            best_indices = (-scores).argsort()[: self.beam_width]
            beam = [potential_next_beam[i] for i in best_indices]
        return best_result

    @property
    def is_black_box(self):
        """Always ``True`` for this search method."""
        return True

    def extra_repr_keys(self):
        """Return the attribute names to include in this object's repr."""
        return ["beam_width"]