DCWIR-Offcial-Demo / textattack /search_methods /population_based_search.py
PFEemp2024's picture
solving GPU error for previous version
4a1df2e
"""
Population based Search
==========================
"""
from abc import ABC, abstractmethod
from textattack.search_methods import SearchMethod
class PopulationBasedSearch(SearchMethod, ABC):
"""Abstract base class for population-based search methods.
Examples include: genetic algorithm, particle swarm optimization
"""
def _check_constraints(self, transformed_text, current_text, original_text):
"""Check if `transformted_text` still passes the constraints with
respect to `current_text` and `original_text`.
This method is required because of a lot population-based methods does their own transformations apart from
the actual `transformation`. Examples include `crossover` from `GeneticAlgorithm` and `move` from `ParticleSwarmOptimization`.
Args:
transformed_text (AttackedText): Resulting text after transformation
current_text (AttackedText): Recent text from which `transformed_text` was produced from.
original_text (AttackedText): Original text
Returns
`True` if constraints satisfied and `False` if otherwise.
"""
filtered = self.filter_transformations(
[transformed_text], current_text, original_text=original_text
)
return True if filtered else False
@abstractmethod
def _perturb(self, pop_member, original_result, **kwargs):
"""Perturb `pop_member` in-place.
Must be overridden by specific population-based method
Args:
pop_member (PopulationMember): Population member to perturb\
original_result (GoalFunctionResult): Result for original text. Often needed for constraint checking.
Returns
`True` if perturbation occured. `False` if not.
"""
raise NotImplementedError()
@abstractmethod
def _initialize_population(self, initial_result, pop_size):
"""
Initialize a population of size `pop_size` with `initial_result`
Args:
initial_result (GoalFunctionResult): Original text
pop_size (int): size of population
Returns:
population as `list[PopulationMember]`
"""
raise NotImplementedError
class PopulationMember:
"""Represent a single member of population."""
def __init__(self, attacked_text, result=None, attributes={}, **kwargs):
self.attacked_text = attacked_text
self.result = result
self.attributes = attributes
for key, value in kwargs.items():
setattr(self, key, value)
@property
def score(self):
if not self.result:
raise ValueError(
"Result must be obtained for PopulationMember to get its score."
)
return self.result.score
@property
def words(self):
return self.attacked_text.words
@property
def num_words(self):
return self.attacked_text.num_words