import os
import tensorflow.compat.v1 as tf
def tf_str_len(s):
    """
    Returns the length of tf.string s
    """
    return tf.size(tf.string_split([s], ""))
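# Example: tf_str_len(tf.constant("MKV")) evaluates to 3; string_split with an
# empty delimiter splits the string into single characters, and tf.size counts them.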
def tf_rank1_tensor_len(t):
    """
    Returns the length of a rank 1 tensor t as a rank 0 int32
    """
    l = tf.reduce_sum(tf.sign(tf.abs(t)), 0)
    return tf.cast(l, tf.int32)
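# Example: tf_rank1_tensor_len(tf.constant([24, 1, 2, 25, 0, 0])) evaluates to 4.
# This counts nonzero entries, so it assumes 0 is reserved for padding and never
# appears inside a real sequence.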
def tf_seq_to_tensor(s):
    """
    Input: a tf.string of comma-separated integers.
    Returns a rank 1 int32 tensor with the length of the input sequence.
    """
    return tf.string_to_number(
        tf.sparse_tensor_to_dense(tf.string_split([s], ","), default_value='0'),
        out_type=tf.int32
    )[0]
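# Example: tf_seq_to_tensor(tf.constant("24,1,2,25")) evaluates to [24, 1, 2, 25].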
def smart_length(length, bucket_bounds=tf.constant([128, 256])):
    """
    Hash the given length into the windows given by bucket_bounds.
    """
    # num_buckets = tf_len(bucket_bounds) + tf.constant(1)
    # Subtract length so that smaller bins are negative, then take the sign
    # Eg: if length is 129, signed = [-1, 1]
    signed = tf.sign(bucket_bounds - length)
    # Now make 1 everywhere that length is greater than or equal to the bound, else 0
    greater = tf.sign(tf.abs(signed - tf.constant(1)))
    # Now simply sum to count the number of bounds not exceeding length
    key = tf.cast(tf.reduce_sum(greater), tf.int64)
    # This will be between 0 and len(bucket_bounds)
    return key
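# Worked example with the default bounds [128, 256] and length 129:
#   signed  = sign([128, 256] - 129) = [-1, 1]
#   greater = sign(abs([-1, 1] - 1)) = [1, 0]
#   key     = sum([1, 0])            = 1  (the middle bucket)
# Overall: len < 128 -> key 0, 128 <= len < 256 -> key 1, len >= 256 -> key 2.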
def pad_batch(ds, batch_size, padding=None, padded_shapes=[None]):
    """
    Helper for bucketbatchpad; pads with zeros by default.
    """
    return ds.padded_batch(batch_size,
                           padded_shapes=padded_shapes,
                           padding_values=padding
                           )
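# `aa_to_int` is used below but not defined in this file. The sketch here is a
# placeholder vocabulary, assuming the usual setup: each amino acid letter maps
# to a nonzero integer (0 is reserved for padding), plus special 'start' and
# 'stop' tokens. Swap in the project's real vocabulary before use.
aa_to_int = {aa: i + 1 for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}
aa_to_int['start'] = len(aa_to_int) + 1  # 21
aa_to_int['stop'] = len(aa_to_int) + 1   # 22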
def aas_to_int_seq(aa_seq):
    int_seq = ""
    for aa in aa_seq:
        int_seq += str(aa_to_int[aa]) + ","
    return str(aa_to_int['start']) + "," + int_seq + str(aa_to_int['stop'])
# Preprocessing in python
def fasta_to_input_format(source, destination):
    # I don't know exactly how to do this in tf, so resorting to python.
    # Should go line by line so everything is not loaded into memory
    sourcefile = os.path.join(source)
    destination = os.path.join(destination)
    with open(sourcefile, 'r') as f:
        with open(destination, 'w') as dest:
            seq = ""
            for line in f:
                if line[0] == '>' and seq != "":
                    dest.write(aas_to_int_seq(seq) + '\n')
                    seq = ""
                elif line[0] != '>':
                    seq += line.replace("\n", "")
            # Write the final sequence, which the loop above never flushes
            if seq != "":
                dest.write(aas_to_int_seq(seq) + '\n')
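# Example (the raw input path is hypothetical): converts a FASTA file into one
# comma-separated integer sequence per line, as expected by the pipelines below:
#   fasta_to_input_format("./data/SwissProt/sprot.fasta",
#                         "./data/SwissProt/sprot_ints.fasta")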
# Real data pipelines
def bucketbatchpad(
    batch_size=256,
    path_to_data=os.path.join("./data/SwissProt/sprot_ints.fasta"),  # Preprocessed- see note
    compressed="",  # Currently unused; see tf.data.TextLineDataset init args
    bounds=[128, 256],  # Default buckets of len < 128, 128 <= len < 256, len >= 256
    # Unclear exactly what this does; should probably equal batch_size
    window_size=256,  # NOT a tensor
    padding=None,  # Use default padding of zero, otherwise see Dataset docs
    shuffle_buffer=None,  # None or the size of the buffer to shuffle with
    pad_shape=[None],
    repeat=1,
    filt=None
):
"""
Streams data from path_to_data that is correctly preprocessed.
Divides into buckets given by bounds and pads to full length.
Returns a dataset which will return a padded batch of batchsize
with iteration.
"""
batch_size=tf.constant(batch_size, tf.int64)
bounds=tf.constant(bounds)
window_size=tf.constant(window_size, tf.int64)
path_to_data = os.path.join(path_to_data)
# Parse strings to tensors
dataset = tf.data.TextLineDataset(path_to_data).map(tf_seq_to_tensor)
if filt is not None:
dataset = dataset.filter(filt)
if shuffle_buffer:
# Stream elements uniformly randomly from a buffer
dataset = dataset.shuffle(buffer_size=shuffle_buffer)
# Apply a repeat. Because this is after the shuffle, all elements of the dataset should be seen before repeat.
# See https://stackoverflow.com/questions/44132307/tf-contrib-data-dataset-repeat-with-shuffle-notice-epoch-end-mixed-epochs
dataset = dataset.repeat(count=repeat)
# Apply grouping to bucket and pad
group_fn = tf.data.experimental.group_by_window(
key_func=lambda seq: smart_length(tf_rank1_tensor_len(seq), bucket_bounds=bounds), # choose a bucket
reduce_func=lambda key, ds: pad_batch(ds, batch_size, padding=padding, padded_shapes=pad_shape), # apply reduce funtion to pad
window_size=window_size)
grouped_dataset = dataset.apply(group_fn)
return grouped_dataset
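# Minimal usage sketch (assumes the preprocessed data file exists and TF1-style
# graph execution, e.g. after tf.disable_v2_behavior()):
#   ds = bucketbatchpad(batch_size=32, shuffle_buffer=10000)
#   batch = tf.data.make_one_shot_iterator(ds).get_next()
#   with tf.Session() as sess:
#       padded = sess.run(batch)  # shape: (<=32, longest_seq_in_batch)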
def shufflebatch(
    batch_size=256,
    shuffle_buffer=None,
    repeat=1,
    path_to_data="./data/SwissProt/sprot_ints.fasta"
):
"""
Draws from an (optionally shuffled) dataset, repeats dataset repeat times,
and serves batches of the specified size.
"""
path_to_data = os.path.join(path_to_data)
# Parse strings to tensors
dataset = tf.contrib.data.TextLineDataset(path_to_data).map(tf_seq_to_tensor)
if shuffle_buffer:
# Stream elements uniformly randomly from a buffer
dataset = dataset.shuffle(buffer_size=shuffle_buffer)
# Apply a repeat. Because this is after the shuffle, all elements of the dataset should be seen before repeat.
# See https://stackoverflow.com/questions/44132307/tf-contrib-data-dataset-repeat-with-shuffle-notice-epoch-end-mixed-epochs
dataset = dataset.repeat(count=repeat)
dataset = dataset.batch(batch_size)
return dataset