---
tags:
- pytorch_model_hub_mixin
- model_hub_mixin
---

This model adds a classification head on top of [LofiAmazon/BarcodeBERT-Entire-BOLD](https://huggingface.co/LofiAmazon/BarcodeBERT-Entire-BOLD). The classification head is a linear layer applied to the DNA embeddings concatenated with environmental layer data. The model was trained on [BOLD-Embeddings-Ecolayers-Amazon](https://huggingface.co/datasets/LofiAmazon/BOLD-Embeddings-Ecolayers-Amazon) to predict taxonomic genera. That dataset includes DNA embeddings and ecological raster layer data for each sample location.

## Example Usage

First, download the `StandardScaler` used to normalise environmental layer values during training. You can find it [here](https://huggingface.co/spaces/LofiAmazon/LofiAmazonSpace/blob/main/scaler.pkl).

You will also need to download all ecological layers from [here](https://huggingface.co/datasets/LofiAmazon/Global-Ecolayers/tree/main).

The model outputs logits which, after a softmax, give a probability distribution over the genera that were present in the training dataset. You can find the mapping between the index in the vector and the genus name in [this file](https://huggingface.co/spaces/LofiAmazon/LofiAmazonSpace/blob/main/genus_labels.json).

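The usage code in this section expects the DNA sequence with a space after every 4 characters and at most 660 characters (not counting spaces). The helper below is a minimal sketch of that formatting step. `format_dna` and its parameter names are hypothetical and not part of the model repository, and raising an error on over-long input (rather than truncating) is an assumption.

```python
def format_dna(sequence: str, chunk_size: int = 4, max_length: int = 660) -> str:
    """Insert a space after every `chunk_size` characters of a DNA sequence."""
    # Drop any existing whitespace so already-formatted input is handled too.
    cleaned = "".join(sequence.split())
    if len(cleaned) > max_length:
        raise ValueError(
            f"sequence has {len(cleaned)} characters; expected at most {max_length}"
        )
    # Characters other than A, C, T, G (such as "-") are kept unchanged.
    chunks = [cleaned[i:i + chunk_size] for i in range(0, len(cleaned), chunk_size)]
    return " ".join(chunks)


print(format_dna("AACAATGTATTTA-T-TTCG"))  # prints "AACA ATGT ATTT A-T- TTCG"
```

The returned string can be passed to the tokenizer in place of a hand-formatted `dna_sequence`.
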
```py
import pickle

import numpy as np
import rasterio
import torch
from huggingface_hub import PyTorchModelHubMixin
from rasterio.sample import sample_gen
from torch import nn
from transformers import BertConfig, BertForMaskedLM, PreTrainedTokenizerFast


class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, bert_model, env_dim, num_classes):
        super().__init__()
        self.bert = bert_model
        self.env_dim = env_dim
        self.num_classes = num_classes
        # Linear head over the 768-dim DNA embedding concatenated with the environmental features
        self.fc = nn.Linear(768 + env_dim, num_classes)

    def forward(self, bert_inputs, env_data):
        outputs = self.bert(**bert_inputs)
        # Mean-pool the last hidden layer over the token dimension
        dna_embeddings = outputs.hidden_states[-1].mean(1)
        combined = torch.cat((dna_embeddings, env_data), dim=1)
        logits = self.fc(combined)

        return logits


classification_model = DNASeqClassifier.from_pretrained(
    "LofiAmazon/BarcodeBERT-Finetuned-Amazon",
    bert_model=BertForMaskedLM(
        BertConfig(vocab_size=259, output_hidden_states=True),
    ),
)
classification_model.eval()

ecolayers = [
    "median_elevation_1km.tiff",
    "human_footprint.tiff",
    "population_density_1km.tif",
    "annual_precipitation.tif",
    "precipitation_seasonality.tif",
    "annual_mean_air_temp.tif",
    "temp_seasonality.tif",
]

with open("scaler.pkl", "rb") as f:
    scaler = pickle.load(f)

tokenizer = PreTrainedTokenizerFast.from_pretrained("LofiAmazon/BarcodeBERT-Entire-BOLD")

# The DNA sequence you want to predict.
# There should be a space after every 4 characters.
# The sequence may also have unknown characters which are not A,C,T,G.
# The maximum DNA sequence length (not counting spaces) should be 660 characters
dna_sequence = "AACA ATGT ATTT A-T- TTCG CCCT TGTG AATT TATT ..."
# Location where DNA was sampled
coords = (-3.009083, -58.68281)

# Tokenize the DNA sequence
dna_tokenized = tokenizer(dna_sequence, return_tensors="pt")

# Obtain the environmental data from the coordinates
env_data = []
for layer in ecolayers:
    with rasterio.open(layer) as dataset:
        # Get the corresponding ecological values for the sample.
        # Note: rasterio samples at (x, y) positions in the dataset's CRS.
        results = list(sample_gen(dataset, [coords]))
        # Average over the raster bands at this location
        layer_data = np.mean(results[0])
        env_data.append(layer_data)
env_data = scaler.transform([env_data])
env_data = torch.from_numpy(env_data).to(torch.float32)

# Obtain genus prediction (no gradients are needed for inference)
with torch.no_grad():
    logits = classification_model(dna_tokenized, env_data)
temperature = 0.2

# Obtain the final genus probabilities
probs = torch.softmax(logits / temperature, dim=1).squeeze()
```
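To turn the probability vector into genus names, the indices have to be looked up in `genus_labels.json`. The sketch below assumes that file is a JSON array of genus names ordered by class index. Its actual layout is not documented here, so check the file before relying on this. `load_genus_labels` and `top_genera` are hypothetical helper names, and the toy labels and probabilities are placeholders, not real model output.

```python
import json


def load_genus_labels(path):
    """Load the index-to-genus mapping.

    Assumption: the file holds a JSON array of genus names, where
    position i is the name for class index i.
    """
    with open(path) as f:
        labels = json.load(f)
    if not isinstance(labels, list):
        raise TypeError("expected a JSON array of genus names")
    return labels


def top_genera(probs, labels, k=5):
    """Return the k most probable (genus, probability) pairs, highest first."""
    if len(probs) != len(labels):
        raise ValueError("probs and labels must have the same length")
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Placeholder values for illustration only.
toy_labels = ["GenusA", "GenusB", "GenusC"]
toy_probs = [0.1, 0.7, 0.2]
print(top_genera(toy_probs, toy_labels, k=2))  # prints [('GenusB', 0.7), ('GenusC', 0.2)]
```

With the real outputs, this would be called as `top_genera(probs.tolist(), load_genus_labels("genus_labels.json"))`, where `probs` is the tensor computed in the usage code.
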