import os

import tensorflow.compat.v1 as tf


def tf_str_len(s):
    """
    Returns the length of a tf.string s.
    """
    return tf.size(tf.string_split([s], ""))


def tf_rank1_tensor_len(t):
    """
    Returns the length of a rank 1 tensor t as a rank 0 int32,
    computed by counting the nonzero entries.
    """
    l = tf.reduce_sum(tf.sign(tf.abs(t)), 0)
    return tf.cast(l, tf.int32)


def tf_seq_to_tensor(s):
    """
    Input a tf.string of comma-separated integers.
    Returns a rank 1 int32 tensor the length of the input sequence.
    """
    return tf.string_to_number(
        tf.sparse_tensor_to_dense(tf.string_split([s], ","), default_value='0'),
        out_type=tf.int32
    )[0]
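# A small illustration of tf_seq_to_tensor (the literal string is
# hypothetical example data; TF1-style graph mode assumed):
#
#   seq = tf_seq_to_tensor(tf.constant("24,1,5,25"))
#   with tf.Session() as sess:
#       print(sess.run(seq))  # -> [24  1  5 25]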


def smart_length(length, bucket_bounds=tf.constant([128, 256])):
    """
    Hash the given length into the windows given by bucket_bounds.
    """
    # +1 where the bound is above length, 0 where equal, -1 where below.
    signed = tf.sign(bucket_bounds - length)

    # 0 where the bound is strictly greater than length, 1 otherwise.
    greater = tf.sign(tf.abs(signed - tf.constant(1)))

    # The key is the number of bounds that length has reached.
    key = tf.cast(tf.reduce_sum(greater), tf.int64)

    return key
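# A worked illustration with bucket_bounds = [128, 256]:
#   length = 100 -> signed = [ 1,  1] -> greater = [0, 0] -> key = 0
#   length = 200 -> signed = [-1,  1] -> greater = [1, 0] -> key = 1
#   length = 300 -> signed = [-1, -1] -> greater = [1, 1] -> key = 2
# so the keys index the windows (-inf, 128), [128, 256), and [256, inf).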


def pad_batch(ds, batch_size, padding=None, padded_shapes=([None])):
    """
    Helper for bucketbatchpad; pads with zeros unless a padding value
    is given.
    """
    return ds.padded_batch(batch_size,
                           padded_shapes=padded_shapes,
                           padding_values=padding
                           )


def aas_to_int_seq(aa_seq):
    """
    Encodes an amino-acid string as comma-separated integers framed by
    start and stop tokens. Relies on aa_to_int, a mapping from each
    amino-acid character (plus 'start' and 'stop') to an integer code,
    assumed to be defined elsewhere in this module.
    """
    int_seq = ""
    for aa in aa_seq:
        int_seq += str(aa_to_int[aa]) + ","
    return str(aa_to_int['start']) + "," + int_seq + str(aa_to_int['stop'])


def fasta_to_input_format(source, destination):
    """
    Converts the FASTA file at source into one integer-encoded sequence
    per line at destination.
    """
    sourcefile = os.path.join(source)
    destination = os.path.join(destination)
    with open(sourcefile, 'r') as f:
        with open(destination, 'w') as dest:
            seq = ""
            for line in f:
                if line[0] == '>' and not seq == "":
                    dest.write(aas_to_int_seq(seq) + '\n')
                    seq = ""
                elif not line[0] == '>':
                    seq += line.replace("\n", "")
            # Flush the final record, which is not followed by a '>' header.
            if seq != "":
                dest.write(aas_to_int_seq(seq) + '\n')
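# Example usage (a sketch; the source path is hypothetical, the
# destination matches the default consumed by bucketbatchpad below):
#
#   fasta_to_input_format("./data/SwissProt/sprot.fasta",
#                         "./data/SwissProt/sprot_ints.fasta")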


def bucketbatchpad(
        batch_size=256,
        path_to_data=os.path.join("./data/SwissProt/sprot_ints.fasta"),
        compressed="",  # unused as written
        bounds=[128, 256],
        window_size=256,
        padding=None,
        shuffle_buffer=None,
        pad_shape=([None]),
        repeat=1,
        filt=None
):
    """
    Streams correctly preprocessed data from path_to_data.
    Divides it into the buckets given by bounds and pads each batch to
    the bucket's full length. Returns a dataset which yields a padded
    batch of batch_size on iteration.
    """
    batch_size = tf.constant(batch_size, tf.int64)
    bounds = tf.constant(bounds)
    window_size = tf.constant(window_size, tf.int64)

    path_to_data = os.path.join(path_to_data)

    dataset = tf.data.TextLineDataset(path_to_data).map(tf_seq_to_tensor)
    if filt is not None:
        dataset = dataset.filter(filt)

    if shuffle_buffer:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)

    dataset = dataset.repeat(count=repeat)

    # Group sequences by length bucket, then pad each window into a
    # rectangular batch.
    group_fn = tf.data.experimental.group_by_window(
        key_func=lambda seq: smart_length(tf_rank1_tensor_len(seq), bucket_bounds=bounds),
        reduce_func=lambda key, ds: pad_batch(ds, batch_size, padding=padding, padded_shapes=pad_shape),
        window_size=window_size)
    grouped_dataset = dataset.apply(group_fn)
    return grouped_dataset
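# A minimal usage sketch (assumes the default preprocessed file exists;
# the 500-residue cutoff is a hypothetical filter):
#
#   keep_short = lambda seq: tf.less(tf_rank1_tensor_len(seq), 500)
#   batches = bucketbatchpad(batch_size=32, window_size=32,
#                            shuffle_buffer=10000, filt=keep_short)
#   next_batch = tf.data.make_one_shot_iterator(batches).get_next()
#   with tf.Session() as sess:
#       print(sess.run(next_batch).shape)  # e.g. (32, longest_in_bucket)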


def shufflebatch(
        batch_size=256,
        shuffle_buffer=None,
        repeat=1,
        path_to_data="./data/SwissProt/sprot_ints.fasta"
):
    """
    Draws from an (optionally shuffled) dataset, repeats the dataset
    repeat times, and serves batches of the specified size.
    """
    path_to_data = os.path.join(path_to_data)

    dataset = tf.data.TextLineDataset(path_to_data).map(tf_seq_to_tensor)
    if shuffle_buffer:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)

    dataset = dataset.repeat(count=repeat)
    dataset = dataset.batch(batch_size)
    return dataset
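# Usage sketch (hypothetical sizes). Note that Dataset.batch requires
# every element in a batch to share a shape, so unlike bucketbatchpad
# this only works when the batched sequences have equal lengths:
#
#   ds = shufflebatch(batch_size=32, shuffle_buffer=10000, repeat=2)
#   next_batch = tf.data.make_one_shot_iterator(ds).get_next()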