Spaces:
Running
Running
""" | |
Population-based Search
==========================
""" | |
from abc import ABC, abstractmethod | |
from textattack.search_methods import SearchMethod | |
class PopulationBasedSearch(SearchMethod, ABC):
    """Abstract base class for population-based search methods.

    Examples include genetic algorithms and particle swarm optimization.
    """

    def _check_constraints(self, transformed_text, current_text, original_text):
        """Report whether `transformed_text` still passes the constraints.

        The check is made with respect to both `current_text` and
        `original_text`. It is needed because many population-based methods
        apply their own operations apart from the actual `transformation`,
        e.g. `crossover` from `GeneticAlgorithm` and `move` from
        `ParticleSwarmOptimization`.

        Args:
            transformed_text (AttackedText): Text resulting from the transformation.
            current_text (AttackedText): Most recent text from which
                `transformed_text` was produced.
            original_text (AttackedText): Original text.

        Returns:
            `True` if the value returned by `filter_transformations` for the
            single candidate is truthy, `False` otherwise.
        """
        surviving = self.filter_transformations(
            [transformed_text], current_text, original_text=original_text
        )
        return bool(surviving)

    def _perturb(self, pop_member, original_result, **kwargs):
        """Perturb `pop_member` in-place.

        Must be overridden by the specific population-based method; this base
        implementation only raises `NotImplementedError`.

        Args:
            pop_member (PopulationMember): Population member to perturb.
            original_result (GoalFunctionResult): Result for the original text.
                Often needed for constraint checking.

        Returns:
            `True` if a perturbation occurred, `False` if not.
        """
        raise NotImplementedError()

    def _initialize_population(self, initial_result, pop_size):
        """Initialize a population of size `pop_size` from `initial_result`.

        This base implementation only raises `NotImplementedError`; concrete
        search methods are expected to provide their own.

        Args:
            initial_result (GoalFunctionResult): Result for the original text.
            pop_size (int): Size of the population.

        Returns:
            The population as `list[PopulationMember]`.
        """
        raise NotImplementedError
class PopulationMember:
    """Represent a single member of a population.

    Args:
        attacked_text: Text carried by this member. `words()` and
            `num_words()` read its `words` and `num_words` attributes
            (presumably an `AttackedText` -- confirm against callers).
        result: Optional result object for `attacked_text`; `score()` reads
            its `score` attribute. Defaults to `None`.
        attributes (dict, optional): Extra per-member data. When omitted, each
            member receives its own new empty dict. A dict passed in is stored
            by reference, not copied.
        **kwargs: Each keyword is set as an attribute on the instance.
    """

    def __init__(self, attacked_text, result=None, attributes=None, **kwargs):
        self.attacked_text = attacked_text
        self.result = result
        # A literal `{}` default is created once at definition time and would
        # be shared by every member constructed without `attributes`, so one
        # member's mutations would leak into the others.
        self.attributes = {} if attributes is None else attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    def score(self):
        """Return `self.result.score`.

        Raises:
            ValueError: If `self.result` is falsy (e.g. still `None`).
        """
        if not self.result:
            raise ValueError(
                "Result must be obtained for PopulationMember to get its score."
            )
        return self.result.score

    def words(self):
        """Return the `words` attribute of the underlying `attacked_text`."""
        return self.attacked_text.words

    def num_words(self):
        """Return the `num_words` attribute of the underlying `attacked_text`."""
        return self.attacked_text.num_words