File size: 3,722 Bytes
4943752 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import json
import os
import random
import numpy as np
import torch
import textattack
# Default torch device for this module. If the ``TA_DEVICE`` environment
# variable is set, its raw value is used; otherwise CUDA is chosen when
# ``torch.cuda.is_available()`` and the CPU otherwise.
# NOTE(review): the type differs by branch -- a plain ``str`` from the
# environment vs. a ``torch.device`` from the fallback; callers should only
# rely on what both support (e.g. passing it to ``.to(...)``).
device = os.environ.get(
    "TA_DEVICE",
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
)
def html_style_from_dict(style_dict):
    """Turn a mapping of CSS properties into an HTML ``style`` attribute.

    Example::

        {'color': 'red', 'height': '100px'}  ->  style="color: red;height: 100px;"

    Every ``key: value`` pair is terminated by ``;`` with no separating
    space. Keys and values are converted with ``str()``, so numeric values
    such as ``{'width': 5}`` are accepted. Nothing is HTML-escaped.

    Args:
        style_dict: Mapping of CSS property names to values.

    Returns:
        The attribute text ``style="..."`` (``style=""`` for an empty mapping).
    """
    style_str = "".join(f"{key!s}: {value!s};" for key, value in style_dict.items())
    return 'style="{}"'.format(style_str)
def html_table_from_rows(rows, title=None, header=None, style_dict=None):
    """Render ``rows`` as an HTML table inside a (optionally styled) ``<div>``.

    Args:
        rows: Iterable of rows; each row is an iterable of cell values, each
            rendered with ``str()`` inside ``<td>`` tags.
        title: Optional title, rendered as ``<h1>`` above the table.
        header: Optional iterable of column headings, rendered as ``<th>``.
        style_dict: Optional mapping of CSS properties applied to the
            container ``<div>`` via ``html_style_from_dict``.

    Returns:
        The complete markup as one string. Cell values are inserted verbatim
        (not HTML-escaped), so callers may embed markup such as colored spans.
    """
    # Stylize the container div.
    if style_dict:
        parts = ["<div {}>".format(html_style_from_dict(style_dict))]
    else:
        parts = ["<div>"]
    # Optional title above the table.
    if title:
        parts.append("<h1>{}</h1>".format(title))
    # Append the table after the div/title markup; it must not replace it,
    # otherwise the closing "</div>" below would be unbalanced.
    parts.append('<table class="table">')
    if header:
        parts.append("<tr>")
        parts.extend(f"<th>{element!s}</th>" for element in header)
        parts.append("</tr>")
    for row in rows:
        parts.append("<tr>")
        parts.extend(f"<td>{element!s}</td>" for element in row)
        parts.append("</tr>")
    # Close the table and the container div.
    parts.append("</table></div>")
    return "".join(parts)
def get_textattack_model_num_labels(model_name, model_path):
    """Return the number of labels recorded in a trained model's ``train_args.json``.

    Args:
        model_name: Model identifier. Not used by this function; kept for
            call-site compatibility.
        model_path: Passed to ``textattack.shared.utils.download_from_s3``;
            the directory it returns is searched for ``train_args.json``.

    Returns:
        The ``"num_labels"`` entry of ``train_args.json``, or ``2`` when the
        file does not exist or has no such key.
    """
    model_cache_path = textattack.shared.utils.download_from_s3(model_path)
    train_args_path = os.path.join(model_cache_path, "train_args.json")
    if not os.path.exists(train_args_path):
        # ``Logger.warn`` is a deprecated alias; ``warning`` is the supported name.
        textattack.shared.logger.warning(
            f"train_args.json not found in model path {model_path}. Defaulting to 2 labels."
        )
        return 2
    # Context manager closes the handle deterministically; JSON text is UTF-8.
    with open(train_args_path, encoding="utf-8") as f:
        args = json.load(f)
    return args.get("num_labels", 2)
def load_textattack_model_from_path(model_name, model_path):
    """Load a pre-trained TextAttack model from its name and path.

    The architecture is selected by the prefix of ``model_name``:

    - ``"lstm..."`` -> ``textattack.models.helpers.LSTMForClassification``
    - ``"cnn..."``  -> ``textattack.models.helpers.WordCNNForClassification``
    - ``"t5..."``   -> ``textattack.models.helpers.T5ForTextToText``

    For LSTM and CNN models the label count is read via
    ``get_textattack_model_num_labels``. For example, model_name
    ``"lstm-yelp"`` with model_path ``"models/classification/lstm/yelp"``.

    Raises:
        ValueError: If ``model_name`` starts with none of the prefixes above.
    """
    colored_model_name = textattack.shared.utils.color_text(
        model_name, color="blue", method="ansi"
    )
    if model_name.startswith("lstm"):
        num_labels = get_textattack_model_num_labels(model_name, model_path)
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack LSTM: {colored_model_name}"
        )
        model = textattack.models.helpers.LSTMForClassification(
            model_path=model_path, num_labels=num_labels
        )
    elif model_name.startswith("cnn"):
        num_labels = get_textattack_model_num_labels(model_name, model_path)
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack CNN: {colored_model_name}"
        )
        model = textattack.models.helpers.WordCNNForClassification(
            model_path=model_path, num_labels=num_labels
        )
    elif model_name.startswith("t5"):
        model = textattack.models.helpers.T5ForTextToText(model_path)
    else:
        # Dispatch above is on the model *name*, so report it; the path alone
        # does not tell the user which prefix failed to match.
        raise ValueError(
            f"Unknown textattack model {model_name} (model path: {model_path})"
        )
    return model
def set_seed(random_seed):
    """Seed every random number generator this module relies on.

    Applies ``random_seed`` to Python's ``random``, NumPy's global RNG,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)
def hashable(key):
    """Report whether ``key`` can be hashed.

    Returns:
        ``False`` if ``hash(key)`` raises ``TypeError``, ``True`` otherwise.
    """
    try:
        hash(key)
    except TypeError:
        return False
    return True
def sigmoid(n):
    """Return the logistic sigmoid ``1 / (1 + exp(-n))``.

    ``np.exp`` is used, so NumPy arrays are handled elementwise. For very
    negative ``n`` the exponential may overflow to ``inf`` (NumPy emits a
    RuntimeWarning) and the result is then 0.
    """
    exp_neg = np.exp(-n)
    return 1 / (1 + exp_neg)
# Module-level mutable dict, empty at import time.
# NOTE(review): its key/value schema is not established in this file — confirm with its users.
GLOBAL_OBJECTS = dict()
# Single-character delimiter; presumably separates packed argument strings — verify against callers.
ARGS_SPLIT_TOKEN = "^"
|