Spaces:
Runtime error
Runtime error
""" | |
Attack Logs to WandB | |
======================== | |
""" | |
from textattack.shared.utils import LazyLoader, html_table_from_rows | |
from .logger import Logger | |
class WeightsAndBiasesLogger(Logger):
    """Logs attack results to Weights & Biases.

    All keyword arguments are forwarded to ``wandb.init`` and kept on the
    instance (``self.kwargs``) so that an unpickled logger can re-initialise
    the run in ``__setstate__``.
    """

    def __init__(self, **kwargs):
        # ``wandb`` is bound lazily as a module-level name; every other
        # method in this class reads that global.
        global wandb
        wandb = LazyLoader("wandb", globals(), "wandb")

        wandb.init(**kwargs)
        self.kwargs = kwargs
        self.project_name = wandb.run.project_name()
        # Accumulated [label, original_html, perturbed_html] rows; the full
        # table is re-logged after every result by ``_log_result_table``.
        self._result_table_rows = []

    def __setstate__(self, state):
        """Restore pickled state and re-initialise the W&B run with resume.

        ``resume=True`` is merged into the saved ``wandb.init`` kwargs
        (overriding any ``resume`` given at construction) instead of being
        passed next to them, which would raise a duplicate-keyword TypeError
        whenever ``self.kwargs`` already contains ``resume``.
        """
        global wandb
        wandb = LazyLoader("wandb", globals(), "wandb")

        self.__dict__ = state
        wandb.init(**{**self.kwargs, "resume": True})

    def log_summary_rows(self, rows, title, window_id):
        """Log summary metrics as a W&B table and as run-summary values.

        Args:
            rows: Iterable of ``(metric_name, value)`` pairs. String values may
                contain ``%`` (e.g. ``"12.5%"``); it is removed and the rest is
                converted to ``float``. ``rows`` itself is not modified.
            title: Accepted for interface compatibility; not used here.
            window_id: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If a string value cannot be converted to ``float``.
        """
        table = wandb.Table(columns=["Attack Results", ""])
        for metric_name, metric_score in rows:
            if isinstance(metric_score, str):
                # Convert into a local rather than writing back into the row:
                # the caller's ``rows`` may be reused after this call, and
                # tuple rows do not support item assignment.
                raw_value = metric_score
                try:
                    metric_score = float(raw_value.replace("%", ""))
                except ValueError as err:
                    raise ValueError(
                        f'Unable to convert row value "{raw_value}" for Attack Result "{metric_name}" into float'
                    ) from err
            table.add_data(metric_name, metric_score)
            wandb.run.summary[metric_name] = metric_score
        wandb.log({"attack_params": table})

    def _log_result_table(self):
        """Weights & Biases doesn't have a feature to automatically aggregate
        results across timesteps and display the full table.

        Therefore, we have to do it manually: every accumulated row is
        rendered into one HTML table and logged under ``"results"``.
        """
        result_table = html_table_from_rows(
            self._result_table_rows, header=["", "Original Input", "Perturbed Input"]
        )
        wandb.log({"results": wandb.Html(result_table)})

    def log_attack_result(self, result):
        """Log a single attack result.

        Logs an HTML diff of the original vs. perturbed text together with
        both model outputs, appends the diff to the cumulative table, and
        re-logs that table.
        """
        original_text_colored, perturbed_text_colored = result.diff_color(
            color_method="html"
        )
        result_num = len(self._result_table_rows)
        self._result_table_rows.append(
            [
                f"<b>Result {result_num}</b>",
                original_text_colored,
                perturbed_text_colored,
            ]
        )
        result_diff_table = html_table_from_rows(
            [[original_text_colored, perturbed_text_colored]]
        )
        result_diff_table = wandb.Html(result_diff_table)
        wandb.log(
            {
                "result": result_diff_table,
                "original_output": result.original_result.output,
                "perturbed_output": result.perturbed_result.output,
            }
        )
        self._log_result_table()

    def log_sep(self):
        """Write a separator line if an output stream has been attached.

        This class never creates ``self.fout``; an unconditional write would
        raise AttributeError. When no stream is attached this does nothing.
        """
        fout = getattr(self, "fout", None)
        if fout is not None:
            fout.write("-" * 90 + "\n")