# Spaces: Runtime error
# (The line above appears to be a page-status banner captured together with
# this listing; it is not part of the module.)
"""
Word CNN for Classification
---------------------------------------------------------------------
"""
import json | |
import os | |
import torch | |
from torch import nn as nn | |
from torch.nn import functional as F | |
import textattack | |
from textattack.model_args import TEXTATTACK_MODELS | |
from textattack.models.helpers import GloveEmbeddingLayer | |
from textattack.models.helpers.utils import load_cached_state_dict | |
from textattack.shared import utils | |
class WordCNNForClassification(nn.Module):
    """A convolutional neural network for text classification.

    We use different versions of this network to pretrain models for
    text classification.

    Args:
        hidden_size (:obj:`int`): Number of convolution filters per kernel width.
        dropout (:obj:`float`): Dropout probability, applied to the embeddings
            and to the encoder output.
        num_labels (:obj:`int`): Size of the output layer.
        max_seq_length (:obj:`int`): Maximum length handed to the tokenizer.
        model_path (:obj:`str`, optional): If given, weights are loaded from this
            path through the deprecated :meth:`load_from_disk`.
        emb_layer_trainable (:obj:`bool`): Forwarded to the GloVe embedding layer.
    """

    def __init__(
        self,
        hidden_size=150,
        dropout=0.3,
        num_labels=2,
        max_seq_length=128,
        model_path=None,
        emb_layer_trainable=True,
    ):
        super().__init__()
        # Kept verbatim so `save_pretrained` can write the constructor arguments
        # and `from_pretrained` can rebuild the model with `cls(**config)`.
        self._config = {
            "architectures": "WordCNNForClassification",
            "hidden_size": hidden_size,
            "dropout": dropout,
            "num_labels": num_labels,
            "max_seq_length": max_seq_length,
            "model_path": model_path,
            "emb_layer_trainable": emb_layer_trainable,
        }
        self.drop = nn.Dropout(dropout)
        self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
        self.word2id = self.emb_layer.word2id

        widths = [3, 4, 5]
        self.encoder = CNNTextLayer(
            self.emb_layer.n_d, widths=widths, filters=hidden_size
        )
        # The encoder concatenates one `hidden_size`-wide feature vector per
        # kernel width, so derive the size from `widths` rather than repeating it.
        d_out = len(widths) * hidden_size
        self.out = nn.Linear(d_out, num_labels)

        self.tokenizer = textattack.models.tokenizers.GloveTokenizer(
            word_id_map=self.word2id,
            unk_token_id=self.emb_layer.oovid,
            pad_token_id=self.emb_layer.padid,
            max_length=max_seq_length,
        )

        if model_path is not None:
            self.load_from_disk(model_path)
        # Start in inference mode whether or not weights were loaded.
        self.eval()

    def load_from_disk(self, model_path):
        """Load a state dict from ``model_path`` and switch to eval mode.

        Deprecated: emits a :class:`DeprecationWarning` on every call.
        """
        # TODO: Consider removing this in the future as well as loading via `model_path` in `__init__`.
        import warnings

        warnings.warn(
            "`load_from_disk` method is deprecated. Please save and load using `save_pretrained` and `from_pretrained` methods.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.load_state_dict(load_cached_state_dict(model_path))
        self.eval()

    def save_pretrained(self, output_path):
        """Write the weights and the constructor config into ``output_path``.

        Creates ``pytorch_model.bin`` (state dict moved to CPU) and
        ``config.json``. The directory is created if it does not exist.

        Args:
            output_path (:obj:`str`): Target directory.
        """
        # `exist_ok=True` avoids the check-then-create race of a separate
        # `os.path.exists` test.
        os.makedirs(output_path, exist_ok=True)
        state_dict = {k: v.cpu() for k, v in self.state_dict().items()}
        torch.save(state_dict, os.path.join(output_path, "pytorch_model.bin"))
        with open(os.path.join(output_path, "config.json"), "w") as f:
            json.dump(self._config, f)

    @classmethod
    def from_pretrained(cls, name_or_path):
        """Load trained Word CNN model by name or from path.

        Args:
            name_or_path (:obj:`str`): Name of the model (e.g. "cnn-imdb") or model saved via :meth:`save_pretrained`.
        Returns:
            :class:`~textattack.models.helpers.WordCNNForClassification` model
        """
        if name_or_path in TEXTATTACK_MODELS:
            path = utils.download_from_s3(TEXTATTACK_MODELS[name_or_path])
        else:
            path = name_or_path

        config_path = os.path.join(path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
        else:
            # Default config
            config = {
                "architectures": "WordCNNForClassification",
                "hidden_size": 150,
                "dropout": 0.3,
                "num_labels": 2,
                "max_seq_length": 128,
                "model_path": None,
                "emb_layer_trainable": True,
            }
        # "architectures" is metadata, not a constructor argument. `pop` with a
        # default tolerates config files that omit the key.
        config.pop("architectures", None)
        model = cls(**config)
        state_dict = load_cached_state_dict(path)
        model.load_state_dict(state_dict)
        return model

    def forward(self, _input):
        """Run embedding -> dropout -> CNN encoder -> dropout -> linear layer.

        Args:
            _input: Passed directly to the embedding layer
                (presumably a batch of token ids -- TODO confirm).
        Returns:
            The output of the final linear layer; no softmax is applied here.
        """
        emb = self.emb_layer(_input)
        emb = self.drop(emb)

        output = self.encoder(emb)
        output = self.drop(output)

        pred = self.out(output)
        return pred

    def get_input_embeddings(self):
        """Return the ``embedding`` attribute of the GloVe embedding layer."""
        return self.emb_layer.embedding
class CNNTextLayer(nn.Module):
    """Parallel 2-D convolutions over a sequence with max-over-time pooling.

    One ``nn.Conv2d`` is built per entry of ``widths``; each kernel spans
    ``w`` sequence positions and the full feature dimension ``n_in``.

    Args:
        n_in (:obj:`int`): Size of the last (feature) dimension of the input.
        widths: Iterable of kernel heights, in sequence positions.
        filters (:obj:`int`): Number of output channels for every convolution.
    """

    # The default is an immutable tuple so it cannot be shared and mutated
    # across instances; any iterable of ints is still accepted.
    def __init__(self, n_in, widths=(3, 4, 5), filters=100):
        super().__init__()
        in_channels = 1  # the input gets a single channel added in `forward`
        self.convs1 = nn.ModuleList(
            [nn.Conv2d(in_channels, filters, (w, n_in)) for w in widths]
        )

    def forward(self, x):
        """Encode ``x`` into one fixed-size vector per batch element.

        Args:
            x: Tensor laid out as (batch, len, d) per the shape comments below.
                No padding is applied, so ``len`` presumably has to be at
                least the largest kernel width -- TODO confirm.
        Returns:
            Tensor of shape (batch, filters * number_of_widths).
        """
        x = x.unsqueeze(1)  # (batch, Ci, len, d)
        # Each kernel covers the whole feature axis, so the last dimension
        # collapses to size 1 and can be squeezed away.
        feature_maps = [
            F.relu(conv(x)).squeeze(3) for conv in self.convs1
        ]  # [(batch, Co, len'), ...]
        # Max over the remaining sequence axis.
        pooled = [
            F.max_pool1d(fm, fm.size(2)).squeeze(2) for fm in feature_maps
        ]  # [(batch, Co), ...]
        return torch.cat(pooled, 1)