"""Run a TextAttack word-substitution attack against a locally trained LSTM classifier.

The script loads an ``LSTMForClassification`` checkpoint from disk and wraps it
for TextAttack. It then attacks a single hand-written example with masked-LM
word swaps, choosing which words to swap by greedy word-importance ranking.
"""
import torch  # NOTE(review): not referenced below; kept in case other code relies on it.

from textattack import Attack, AttackArgs, Attacker
from textattack.datasets import Dataset
from textattack.goal_functions import UntargetedClassification
from textattack.models.helpers import LSTMForClassification
from textattack.models.wrappers import PyTorchModelWrapper
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import WordSwapMaskedLM

# Checkpoint directory produced by an earlier training run (machine-specific path).
MODEL_PATH = "/root/attack/outputs/2024-02-23-10-43-21-925203/best_model"

# Load the trained LSTM and force inference mode, so layers such as dropout cannot
# make predictions vary between queries (harmless if already in eval mode).
lstm_model = LSTMForClassification.from_pretrained(MODEL_PATH)
lstm_model.eval()
# Wrap the raw module so TextAttack can query it with lists of strings. The
# checkpoint's own tokenizer is used to convert text into model inputs.
model = PyTorchModelWrapper(lstm_model, lstm_model.tokenizer)

# Candidate replacement words are proposed by a masked language model.
transformation = WordSwapMaskedLM()
# Attack words in order of importance. Importance is estimated by replacing each
# word with an unknown token and measuring the change in the model's output.
search_method = GreedyWordSwapWIR(wir_method="unk")

# A single (text, ground-truth label) pair to attack.
dataset = Dataset([("I like play basketball.", 1)])

# The attack succeeds when the model's prediction moves away from the ground-truth label.
goal_function = UntargetedClassification(model)
# No constraints are applied: any substitution the transformation proposes is allowed.
# NOTE(review): consider adding constraints if the adversarial text must stay fluent.
attack = Attack(goal_function, [], transformation, search_method)

# Cap the number of model queries spent per example. AttackArgs only take effect
# when they are handed to the Attacker, so they are passed explicitly here.
attack_args = AttackArgs(query_budget=10000)
attacker = Attacker(attack, dataset, attack_args)
attack_results = attacker.attack_dataset()