File size: 3,722 Bytes
4943752
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json
import os
import random

import numpy as np
import torch

import textattack

# Default compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The TA_DEVICE environment variable, when set, takes precedence. Note that in
# that case `device` is the raw environment string rather than a torch.device.
device = os.environ.get("TA_DEVICE", device)


def html_style_from_dict(style_dict):
    """Render a mapping of CSS properties as an HTML ``style`` attribute.

    For example,

        {'color': 'red', 'height': '100px'}

    becomes
        style="color: red;height: 100px;"

    Keys and values are concatenated as-is, so both must be strings.
    """
    declarations = "".join(
        prop + ": " + style_dict[prop] + ";" for prop in style_dict
    )
    return 'style="{}"'.format(declarations)


def html_table_from_rows(rows, title=None, header=None, style_dict=None):
    """Build an HTML table, wrapped in a ``<div>``, from rows of cells.

    Args:
        rows: Iterable of rows; each row is an iterable of cell values. Every
            cell is converted with ``str()`` and inserted without escaping.
        title: Optional heading, rendered as an ``<h1>`` before the table.
        header: Optional iterable of column headings, rendered as ``<th>``
            cells in the first table row.
        style_dict: Optional mapping of CSS properties applied to the
            container ``<div>`` via ``html_style_from_dict``.

    Returns:
        The HTML markup as a single string.
    """
    # Stylize the container div.
    if style_dict:
        parts = ["<div {}>".format(html_style_from_dict(style_dict))]
    else:
        parts = ["<div>"]

    # Emit the title string.
    if title:
        parts.append("<h1>{}</h1>".format(title))

    # The table is appended after the div/title. Reassigning here would drop
    # the opening <div> and the title while still emitting the closing </div>.
    parts.append('<table class="table">')

    # Construct each row as HTML.
    if header:
        parts.append("<tr>")
        parts.extend("<th>" + str(element) + "</th>" for element in header)
        parts.append("</tr>")
    for row in rows:
        parts.append("<tr>")
        parts.extend("<td>" + str(element) + "</td>" for element in row)
        parts.append("</tr>")

    # Close the table and its container.
    parts.append("</table></div>")

    return "".join(parts)


def get_textattack_model_num_labels(model_name, model_path):
    """Reads `train_args.json` and gets the number of labels for a trained
    model, if present.

    Args:
        model_name: Name of the model. Not used by this function; kept for
            interface compatibility.
        model_path: Model path passed to
            ``textattack.shared.utils.download_from_s3``; the returned
            location is searched for ``train_args.json``.

    Returns:
        The ``num_labels`` value stored in ``train_args.json``, or 2 when the
        file does not exist or has no ``num_labels`` entry.
    """
    model_cache_path = textattack.shared.utils.download_from_s3(model_path)
    train_args_path = os.path.join(model_cache_path, "train_args.json")
    if not os.path.exists(train_args_path):
        # `Logger.warn` is a deprecated alias removed in Python 3.13.
        textattack.shared.logger.warning(
            f"train_args.json not found in model path {model_path}. Defaulting to 2 labels."
        )
        return 2

    # Use a context manager so the file handle is closed deterministically.
    with open(train_args_path, encoding="utf-8") as train_args_file:
        args = json.load(train_args_file)
    return args.get("num_labels", 2)


def load_textattack_model_from_path(model_name, model_path):
    """Instantiate a pre-trained TextAttack model given its name and path.

    The prefix of ``model_name`` selects the model class: ``"lstm"``,
    ``"cnn"`` or ``"t5"``. For example, model_name "lstm-yelp" with model
    path "models/classification/lstm/yelp" loads an LSTM classifier.

    Raises:
        ValueError: If ``model_name`` starts with none of the known prefixes.
    """
    helpers = textattack.models.helpers
    display_name = textattack.shared.utils.color_text(
        model_name, color="blue", method="ansi"
    )

    if model_name.startswith("lstm"):
        label_count = get_textattack_model_num_labels(model_name, model_path)
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack LSTM: {display_name}"
        )
        return helpers.LSTMForClassification(
            model_path=model_path, num_labels=label_count
        )

    if model_name.startswith("cnn"):
        label_count = get_textattack_model_num_labels(model_name, model_path)
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack CNN: {display_name}"
        )
        return helpers.WordCNNForClassification(
            model_path=model_path, num_labels=label_count
        )

    if model_name.startswith("t5"):
        return helpers.T5ForTextToText(model_path)

    raise ValueError(f"Unknown textattack model {model_path}")


def set_seed(random_seed):
    """Seed Python's `random`, NumPy, and torch (CPU and CUDA) with one value."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)


def hashable(key):
    """Return True if ``hash(key)`` succeeds, False if it raises TypeError."""
    try:
        hash(key)
    except TypeError:
        return False
    else:
        return True


def sigmoid(n):
    """Compute the logistic function ``1 / (1 + exp(-n))`` using NumPy."""
    decay = np.exp(-n)
    return 1 / (1 + decay)


# Module-level dictionary, empty at import time.
# NOTE(review): presumably filled in by other modules as a shared registry of
# objects — confirm against callers.
GLOBAL_OBJECTS = {}
# Single-character separator token.
# NOTE(review): the name suggests it delimits arguments packed into one
# string — confirm against the code that splits on it.
ARGS_SPLIT_TOKEN = "^"