|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utilities for generating/preprocessing data for adversarial text models.""" |
|
|
|
import operator |
|
import os |
|
import random |
|
import re |
|
|
|
|
|
|
|
import tensorflow as tf |
|
|
|
# End-of-sentence marker appended to every sequence.
EOS_TOKEN = '</s>'

# Seq-autoencoder ('sa') record files (see build_seq_ae_sequence).
ALL_SA = 'all_sa.tfrecords'
TRAIN_SA = 'train_sa.tfrecords'
TEST_SA = 'test_sa.tfrecords'

# Forward language-model ('lm') record files (see build_lm_sequence).
ALL_LM = 'all_lm.tfrecords'
TRAIN_LM = 'train_lm.tfrecords'
TEST_LM = 'test_lm.tfrecords'

# Classification record files (see build_labeled_sequence).
TRAIN_CLASS = 'train_classification.tfrecords'
TEST_CLASS = 'test_classification.tfrecords'
VALID_CLASS = 'validate_classification.tfrecords'

# Reverse language-model record files (see build_reverse_sequence).
TRAIN_REV_LM = 'train_reverse_lm.tfrecords'
TEST_REV_LM = 'test_reverse_lm.tfrecords'

# Bidirectional classification record files (see build_bidirectional_seq).
TRAIN_BD_CLASS = 'train_bidir_classification.tfrecords'
TEST_BD_CLASS = 'test_bidir_classification.tfrecords'
VALID_BD_CLASS = 'validate_bidir_classification.tfrecords'
|
|
|
|
|
class ShufflingTFRecordWriter(object):
  """Thin wrapper around TFRecordWriter that shuffles records.

  Records are buffered in memory by write() and only hit disk, in a
  randomly shuffled order, when close() is called.
  """

  def __init__(self, path):
    self._path = path
    self._records = []
    self._closed = False

  def write(self, record):
    """Buffers `record` for later shuffled output; must not be closed."""
    assert not self._closed
    self._records.append(record)

  def close(self):
    """Shuffles the buffered records and writes them all to `path` once."""
    assert not self._closed
    random.shuffle(self._records)
    with tf.python_io.TFRecordWriter(self._path) as writer:
      for buffered_record in self._records:
        writer.write(buffered_record)
    self._closed = True

  def __enter__(self):
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # Flush on context exit, mirroring file-like objects.
    self.close()
|
|
|
|
|
class Timestep(object):
  """Represents a single timestep in a SequenceWrapper.

  Holds one token Feature, one label Feature, and one weight Feature.
  In multivalent mode the token Feature may hold several token ids.
  """

  def __init__(self, token, label, weight, multivalent_tokens=False):
    """Constructs Timestep from empty Features."""
    self._token = token
    self._label = label
    self._weight = weight
    self._multivalent_tokens = multivalent_tokens
    self._fill_with_defaults()

  @property
  def token(self):
    """The single token id; invalid when multivalent (use `tokens`)."""
    if self._multivalent_tokens:
      raise TypeError('Timestep may contain multiple values; use `tokens`')
    return self._token.int64_list.value[0]

  @property
  def tokens(self):
    """All token ids stored at this timestep."""
    return self._token.int64_list.value

  @property
  def label(self):
    return self._label.int64_list.value[0]

  @property
  def weight(self):
    return self._weight.float_list.value[0]

  def set_token(self, token):
    """Overwrites the single token id; invalid when multivalent."""
    if self._multivalent_tokens:
      raise TypeError('Timestep may contain multiple values; use `add_token`')
    self._token.int64_list.value[0] = token
    return self

  def add_token(self, token):
    """Appends `token` to this timestep's token list."""
    self._token.int64_list.value.append(token)
    return self

  def set_label(self, label):
    self._label.int64_list.value[0] = label
    return self

  def set_weight(self, weight):
    self._weight.float_list.value[0] = weight
    return self

  def copy_from(self, timestep):
    """Copies token, label, and weight over from `timestep`."""
    self.set_token(timestep.token)
    self.set_label(timestep.label)
    self.set_weight(timestep.weight)
    return self

  def _fill_with_defaults(self):
    # Univalent timesteps get a 0 placeholder so setters can assign by
    # index; multivalent token lists start empty and grow via add_token.
    if not self._multivalent_tokens:
      self._token.int64_list.value.append(0)
    self._label.int64_list.value.append(0)
    self._weight.float_list.value.append(0.0)
|
|
|
|
|
class SequenceWrapper(object):
  """Wrapper around tf.SequenceExample exposing a list-like Timestep API."""

  # Feature-list keys inside the underlying SequenceExample.
  F_TOKEN_ID = 'token_id'
  F_LABEL = 'label'
  F_WEIGHT = 'weight'

  def __init__(self, multivalent_tokens=False):
    self._seq = tf.train.SequenceExample()
    self._flist = self._seq.feature_lists.feature_list
    self._timesteps = []
    self._multivalent_tokens = multivalent_tokens

  @property
  def seq(self):
    """The underlying tf.train.SequenceExample proto."""
    return self._seq

  @property
  def multivalent_tokens(self):
    return self._multivalent_tokens

  @property
  def _tokens(self):
    return self._flist[self.F_TOKEN_ID].feature

  @property
  def _labels(self):
    return self._flist[self.F_LABEL].feature

  @property
  def _weights(self):
    return self._flist[self.F_WEIGHT].feature

  def add_timestep(self):
    """Adds a Timestep backed by fresh Feature protos and returns it."""
    new_step = Timestep(
        self._tokens.add(),
        self._labels.add(),
        self._weights.add(),
        multivalent_tokens=self._multivalent_tokens)
    self._timesteps.append(new_step)
    return new_step

  def __iter__(self):
    return iter(self._timesteps)

  def __len__(self):
    return len(self._timesteps)

  def __getitem__(self, idx):
    return self._timesteps[idx]
|
|
|
|
|
def build_reverse_sequence(seq):
  """Builds a sequence that is the reverse of the input sequence.

  All timesteps except the final one are copied in reverse order; the
  final timestep (the EOS marker) stays in place at the end.
  """
  reverse_seq = SequenceWrapper()

  # Walk backwards from the second-to-last timestep to the first.
  for timestep in seq[-2::-1]:
    reverse_seq.add_timestep().copy_from(timestep)

  # EOS remains the last timestep.
  reverse_seq.add_timestep().copy_from(seq[-1])

  return reverse_seq
|
|
|
|
|
def build_bidirectional_seq(seq, rev_seq):
  """Zips forward and reverse sequences into one multivalent sequence.

  Each output timestep holds two tokens: the forward token followed by
  the reverse token at the same position.
  """
  bidir_seq = SequenceWrapper(multivalent_tokens=True)
  for forward_ts, reverse_ts in zip(seq, rev_seq):
    paired = bidir_seq.add_timestep()
    paired.add_token(forward_ts.token)
    paired.add_token(reverse_ts.token)

  return bidir_seq
|
|
|
|
|
def build_lm_sequence(seq):
  """Builds language model sequence from input sequence.

  Args:
    seq: SequenceWrapper.

  Returns:
    SequenceWrapper with `seq` tokens copied over to output sequence tokens
    and labels (offset by 1, i.e. predict next token) with weights set to
    1.0, except for the <eos> token, which predicts itself with weight 0.0.
  """
  lm_seq = SequenceWrapper()
  last_index = len(seq) - 1
  for i, timestep in enumerate(seq):
    lm_ts = lm_seq.add_timestep()
    lm_ts.set_token(timestep.token)
    if i == last_index:
      # <eos> has no next token; give it a zero-weight self-label.
      lm_ts.set_label(seq[i].token)
      lm_ts.set_weight(0.0)
    else:
      lm_ts.set_label(seq[i + 1].token)
      lm_ts.set_weight(1.0)
  return lm_seq
|
|
|
|
|
def build_seq_ae_sequence(seq):
  """Builds seq_ae sequence from input sequence.

  Args:
    seq: SequenceWrapper.

  Returns:
    SequenceWrapper with `seq` inputs copied and concatenated, and with labels
    copied in on the right-hand (i.e. decoder) side with weights set to 1.0.
    The new sequence will have length `len(seq) * 2 - 1`, as the last timestep
    of the encoder section and the first step of the decoder section will
    overlap.
  """
  seq_ae_seq = SequenceWrapper()
  n = len(seq)

  for i in range(n * 2 - 1):
    ts = seq_ae_seq.add_timestep()
    # Tokens wrap around: encoder positions see seq[0..n-1], then the
    # decoder replays them starting at the overlap timestep i == n - 1.
    ts.set_token(seq[i % n].token)
    if i >= n - 1:
      # Decoder side: predict the next input token with full weight.
      # At the overlap step this labels seq[0], starting the replay.
      ts.set_label(seq[(i + 1) % n].token)
      ts.set_weight(1.0)

  return seq_ae_seq
|
|
|
|
|
def build_labeled_sequence(seq, class_label, label_gain=False):
  """Builds labeled sequence from input sequence.

  Args:
    seq: SequenceWrapper.
    class_label: integer, starting from 0.
    label_gain: bool. If True, class_label will be put on every timestep and
      weight will increase linearly from 0 to 1.

  Returns:
    SequenceWrapper with `seq` copied in and `class_label` added as label to
    final timestep.
  """
  label_seq = SequenceWrapper(multivalent_tokens=seq.multivalent_tokens)

  seq_len = len(seq)
  last_timestep = None
  for i, timestep in enumerate(seq):
    labeled_ts = label_seq.add_timestep()

    # Copy the token(s) over, respecting multivalence.
    if seq.multivalent_tokens:
      for token in timestep.tokens:
        labeled_ts.add_token(token)
    else:
      labeled_ts.set_token(timestep.token)

    if label_gain:
      labeled_ts.set_label(int(class_label))
      # Weight ramps linearly from 0 at the start to 1 at the end.
      ramp = 1.0 if seq_len < 2 else float(i) / (seq_len - 1)
      labeled_ts.set_weight(ramp)

    last_timestep = labeled_ts

  # The final timestep always carries the class label at full weight.
  last_timestep.set_label(int(class_label)).set_weight(1.0)

  return label_seq
|
|
|
|
|
def split_by_punct(segment):
  """Splits str segment on non-word characters, dropping empty pieces."""
  pieces = re.split(r'\W+', segment)
  return [piece for piece in pieces if piece and not piece.isspace()]
|
|
|
|
|
def sort_vocab_by_frequency(vocab_freq_map):
  """Sorts vocab_freq_map by count.

  Args:
    vocab_freq_map: dict<str term, int count>, vocabulary terms with counts.

  Returns:
    list<tuple<str term, int count>> sorted by count, descending.
  """
  entries = vocab_freq_map.items()
  return sorted(entries, key=lambda entry: entry[1], reverse=True)
|
|
|
|
|
def write_vocab_and_frequency(ordered_vocab_freqs, output_dir):
  """Writes ordered_vocab_freqs into vocab.txt and vocab_freq.txt.

  Line i of vocab.txt holds the i-th term and line i of vocab_freq.txt
  holds its count, so the two files stay aligned.
  """
  tf.gfile.MakeDirs(output_dir)
  vocab_path = os.path.join(output_dir, 'vocab.txt')
  freq_path = os.path.join(output_dir, 'vocab_freq.txt')
  with open(vocab_path, 'w', encoding='utf-8') as vocab_f, \
      open(freq_path, 'w', encoding='utf-8') as freq_f:
    for word, freq in ordered_vocab_freqs:
      vocab_f.write('{}\n'.format(word))
      freq_f.write('{}\n'.format(freq))
|
|