# NOTE: a Hugging Face Spaces page header ("Spaces / Running on CPU Upgrade")
# was captured during scraping and has been removed; it was not part of the code.
import numpy as np

# Sets of KNOWN characters in SMILES and FASTA sequences.
# Tuples (not sets) are used to preserve character order, so each character's
# index — and therefore every downstream encoding — is stable across runs.
SMILES_VOCAB = ('#', '%', ')', '(', '+', '-', '.', '1', '0', '3', '2', '5', '4',
                '7', '6', '9', '8', '=', 'A', 'C', 'B', 'E', 'D', 'G', 'F', 'I',
                'H', 'K', 'M', 'L', 'O', 'N', 'P', 'S', 'R', 'U', 'T', 'W', 'V',
                'Y', '[', 'Z', ']', '_', 'a', 'c', 'b', 'e', 'd', 'g', 'f', 'i',
                'h', 'm', 'l', 'o', 'n', 's', 'r', 'u', 't', 'y')
FASTA_VOCAB = ('A', 'C', 'B', 'E', 'D', 'G', 'F', 'I', 'H', 'K', 'M', 'L', 'O',
               'N', 'Q', 'P', 'S', 'R', 'U', 'T', 'W', 'V', 'Y', 'X', 'Z')

# Check uniqueness, build character -> index dicts, and reserve index 0 ('?')
# for unknown characters; known characters therefore map to index + 1.
assert len(SMILES_VOCAB) == len(set(SMILES_VOCAB)), 'SMILES_CHARSET has duplicate characters.'
SMILES_CHARSET_IDX = {character: index + 1 for index, character in enumerate(SMILES_VOCAB)} | {'?': 0}
assert len(FASTA_VOCAB) == len(set(FASTA_VOCAB)), 'FASTA_CHARSET has duplicate characters.'
FASTA_CHARSET_IDX = {character: index + 1 for index, character in enumerate(FASTA_VOCAB)} | {'?': 0}
def sequence_to_onehot(sequence: str, charset, max_sequence_length: int):
    """One-hot encode `sequence` against `charset`.

    Index 0 is reserved for unknown characters; character `charset[i]` maps to
    index i + 1. The sequence is truncated to `max_sequence_length`; positions
    past the end of the sequence remain all-zero (no padding symbol is set).

    Args:
        sequence: string to encode.
        charset: iterable of unique characters defining the vocabulary.
        max_sequence_length: fixed output length (columns of the result).

    Returns:
        int ndarray of shape (len(charset) + 1, max_sequence_length).
    """
    assert len(charset) == len(set(charset)), '`charset` contains duplicate characters.'
    charset_idx = {character: index + 1 for index, character in enumerate(charset)} | {'?': 0}
    onehot = np.zeros((max_sequence_length, len(charset_idx)), dtype=int)
    for index, character in enumerate(sequence[:max_sequence_length]):
        # Unknown characters fall back to index 0 ('?').
        onehot[index, charset_idx.get(character, 0)] = 1
    # Transpose to (channels, sequence-position) layout.
    return onehot.transpose()
def sequence_to_label(sequence: str, charset, max_sequence_length: int):
    """Encode `sequence` as a vector of integer labels.

    Character `charset[i]` maps to label i + 1; unknown characters map to 0.
    The sequence is truncated to `max_sequence_length`; positions past the end
    of the sequence are 0 (indistinguishable from unknown characters).

    Args:
        sequence: string to encode.
        charset: iterable of unique characters defining the vocabulary.
        max_sequence_length: fixed output length.

    Returns:
        int ndarray of shape (max_sequence_length,).
    """
    assert len(charset) == len(set(charset)), '`charset` contains duplicate characters.'
    charset_idx = {character: index + 1 for index, character in enumerate(charset)} | {'?': 0}
    label = np.zeros(max_sequence_length, dtype=int)
    for index, character in enumerate(sequence[:max_sequence_length]):
        label[index] = charset_idx.get(character, 0)
    return label
def smiles_to_onehot(smiles: str, smiles_charset=SMILES_VOCAB, max_sequence_length: int = 100):
    """One-hot encode a SMILES string; see `sequence_to_onehot`.

    Returns an int ndarray of shape (len(smiles_charset) + 1, max_sequence_length);
    row 0 marks unknown characters.
    """
    return sequence_to_onehot(smiles, smiles_charset, max_sequence_length)
def smiles_to_label(smiles: str, smiles_charset=SMILES_VOCAB, max_sequence_length: int = 100):
    """Encode a SMILES string as integer labels; see `sequence_to_label`.

    Returns an int ndarray of shape (max_sequence_length,); 0 marks unknown
    characters and padding.
    """
    return sequence_to_label(smiles, smiles_charset, max_sequence_length)
def fasta_to_onehot(fasta: str, fasta_charset=FASTA_VOCAB, max_sequence_length: int = 1000):
    """One-hot encode a FASTA (protein) sequence; see `sequence_to_onehot`.

    Returns an int ndarray of shape (len(fasta_charset) + 1, max_sequence_length);
    row 0 marks unknown characters.
    """
    return sequence_to_onehot(fasta, fasta_charset, max_sequence_length)
def fasta_to_label(fasta: str, fasta_charset=FASTA_VOCAB, max_sequence_length: int = 1000):
    """Encode a FASTA (protein) sequence as integer labels; see `sequence_to_label`.

    Returns an int ndarray of shape (max_sequence_length,); 0 marks unknown
    characters and padding.
    """
    return sequence_to_label(fasta, fasta_charset, max_sequence_length)
def one_of_k_encoding(x, allowable_set):
    """One-of-k (one-hot) encode `x` against `allowable_set`.

    Returns a list of booleans, True exactly at the position of `x`.

    Raises:
        ValueError: if `x` is not a member of `allowable_set`.
        (ValueError is a subclass of Exception, so callers catching the old
        generic Exception still work; the message's stray trailing colon is fixed.)
    """
    if x not in allowable_set:
        raise ValueError(f"input {x} not in allowable set {allowable_set}")
    return [x == s for s in allowable_set]
def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element.

    Returns a list of booleans, True exactly at the position of `x`
    (or at the last position when `x` is unknown).
    """
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]