This model adds a classification head on top of LofiAmazon/BarcodeBERT-Entire-BOLD. The classification head is a linear layer applied to the DNA embeddings concatenated with environmental layer data. The model was trained on BOLD-Embeddings-Ecolayers-Amazon to predict taxonomic genera. BOLD-Embeddings-Ecolayers-Amazon contains DNA embeddings together with ecological raster-layer values at each sample's location.
Example Usage
First, download the StandardScaler that was used to normalise environmental layer values during training (the example below loads it from `scaler.pkl`). You can find it here.
You will also need to download all of the ecological raster layers (the files listed in `ecolayers` in the example below) from here.
The model outputs one score (logit) for each genus present in the training dataset; the example below converts these scores into a probability distribution with a temperature-scaled softmax. You can find the mapping between each index of the output vector and its genus name in this file.
import pickle

import numpy as np
import rasterio
import torch
from huggingface_hub import PyTorchModelHubMixin
from rasterio.sample import sample_gen
from torch import nn
from transformers import PreTrainedTokenizerFast, BertForMaskedLM, BertConfig
class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
    """Predict a genus from a DNA barcode plus environmental layer values.

    The BERT backbone embeds the tokenized DNA sequence. Its last hidden layer
    is mean-pooled over tokens, concatenated with the environmental features,
    and passed through a single linear layer that yields one logit per class.

    Args:
        bert_model: Hugging Face BERT model whose forward accepts the tokenizer
            outputs as keyword arguments and can return ``hidden_states``.
        env_dim: Number of environmental features per sample.
        num_classes: Number of output classes (genera).
    """

    def __init__(self, bert_model, env_dim, num_classes):
        super().__init__()
        self.bert = bert_model
        self.env_dim = env_dim
        self.num_classes = num_classes
        # Width of the pooled DNA embedding. Read it from the backbone's config
        # when available; 768 (the previously hard-coded value) is the fallback,
        # so existing checkpoints keep the same Linear shape.
        hidden_size = getattr(getattr(bert_model, "config", None), "hidden_size", 768)
        self.fc = nn.Linear(hidden_size + env_dim, num_classes)

    def forward(self, bert_inputs, env_data):
        """Return class logits for a batch.

        Args:
            bert_inputs: Mapping of tokenizer outputs (e.g. ``input_ids``,
                ``attention_mask``) forwarded to the backbone as keyword args.
            env_data: Float tensor of shape ``(batch, env_dim)``.

        Returns:
            Unnormalised logits of shape ``(batch, num_classes)``.
        """
        # Request hidden states explicitly so pooling does not depend on the
        # backbone's config having output_hidden_states=True. Merging into a
        # new dict avoids a duplicate-keyword error if the caller already set it.
        outputs = self.bert(**{**bert_inputs, "output_hidden_states": True})
        # Mean over the token axis of the last layer. NOTE(review): padding
        # tokens are included in this mean. That is harmless for a single
        # unpadded sequence, but would shift embeddings in padded batches.
        # Kept as-is to match how the checkpoint was presumably trained.
        dna_embeddings = outputs.hidden_states[-1].mean(1)
        combined = torch.cat((dna_embeddings, env_data), dim=1)
        return self.fc(combined)
# Backbone matching the fine-tuned checkpoint's architecture. Its weights are
# presumably replaced by the checkpoint's state dict during from_pretrained.
# Hidden states are enabled because DNASeqClassifier.forward pools
# outputs.hidden_states.
backbone = BertForMaskedLM(BertConfig(vocab_size=259, output_hidden_states=True))

# Fetch the classification head (and fine-tuned backbone weights) from the Hub.
classification_model = DNASeqClassifier.from_pretrained(
    "LofiAmazon/BarcodeBERT-Finetuned-Amazon",
    bert_model=backbone,
)
# Ecological raster layers sampled at the DNA sample's location. Order matters:
# one value per layer is appended in this order and fed to the scaler, which
# normalises each column with its own training-time statistics.
# NOTE(review): extensions are mixed (.tiff / .tif); they must match the
# downloaded file names exactly.
ecolayers = [
    "median_elevation_1km.tiff",
    "human_footprint.tiff",
    "population_density_1km.tif",
    "annual_precipitation.tif",
    "precipitation_seasonality.tif",
    "annual_mean_air_temp.tif",
    "temp_seasonality.tif",
]

# StandardScaler fitted on the training data's environmental values.
# SECURITY: unpickling can execute arbitrary code embedded in the file. Only
# load a scaler.pkl obtained from a source you trust (the model authors).
with open("scaler.pkl", "rb") as f:
    scaler = pickle.load(f)

# Tokenizer of the pretrained BarcodeBERT backbone the classifier builds on.
tokenizer = PreTrainedTokenizerFast.from_pretrained("LofiAmazon/BarcodeBERT-Entire-BOLD")
# Input DNA barcode, written as groups of 4 characters separated by single
# spaces (e.g. "AACA ATGT"). Characters other than A, C, T and G, such as "-",
# are allowed. Keep it to at most 660 characters, not counting the spaces.
dna_sequence = "AACA ATGT ATTT A-T- TTCG CCCT TGTG AATT TATT ..."

# Where the specimen was sampled: presumably (latitude, longitude) in decimal
# degrees, since these values fall inside the Amazon basin.
coords = (-3.009083, -58.68281)

# Turn the spaced sequence into PyTorch tensors for the BERT backbone.
dna_tokenized = tokenizer(
    dna_sequence,
    return_tensors="pt",
)
# Obtain the environmental data from the coordinates.
# rasterio samples at (x, y) pairs, i.e. (longitude, latitude) in a geographic
# CRS. `coords` holds (latitude, longitude): (-3.0, -58.7) is in the Amazon
# basin, whereas read as (lon, lat) it would point into the Southern Ocean.
# NOTE(review): assumes every raster uses a lon/lat CRS (e.g. EPSG:4326). If a
# layer is projected, transform the point first (rasterio.warp.transform).
# TODO confirm.
latitude, longitude = coords
env_data = []
for layer in ecolayers:
    with rasterio.open(layer) as dataset:
        # sample_gen yields one array of band values per requested point.
        # Only one point is requested, so take the first.
        band_values = next(sample_gen(dataset, [(longitude, latitude)]))
        # Collapse the bands into a single scalar for this layer.
        env_data.append(np.mean(band_values))
# Normalise with the training-time scaler; it expects a 2-D (one-row) input.
env_data = scaler.transform([env_data])
env_data = torch.from_numpy(env_data).to(torch.float32)
# Obtain genus prediction.
# eval() disables dropout so repeated calls give identical output, and
# no_grad() skips building an autograd graph that inference never uses.
classification_model.eval()
with torch.no_grad():
    logits = classification_model(dna_tokenized, env_data)

# A temperature below 1 sharpens the softmax distribution.
# NOTE(review): 0.2 is presumably a calibrated value; confirm before changing.
temperature = 0.2
# Obtain the final genus probabilities; squeeze() drops the batch dimension of 1.
probs = torch.softmax(logits / temperature, dim=1).squeeze()
- Downloads last month
- 11