---
tags:
- pytorch_model_hub_mixin
- model_hub_mixin
---
This model adds a classification head on top of [LofiAmazon/BarcodeBERT-Entire-BOLD](https://huggingface.co/LofiAmazon/BarcodeBERT-Entire-BOLD). The classification head concatenates the mean-pooled DNA embeddings with environmental layer data and passes the result through a single linear layer. The model was trained on
[BOLD-Embeddings-Ecolayers-Amazon](https://huggingface.co/datasets/LofiAmazon/BOLD-Embeddings-Ecolayers-Amazon) to predict taxonomic genera. BOLD-Embeddings-Ecolayers-Amazon contains DNA embeddings together with ecological raster layer data for each sample's location.
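
To make the shape bookkeeping of the head concrete, here is a minimal NumPy sketch of its forward pass, with dummy arrays standing in for BERT's last hidden layer. The 768-dimensional hidden size and the seven environmental features match the usage example; `NUM_CLASSES = 10` and the random weights are purely illustrative, and the real model performs these steps in PyTorch.

```py
import numpy as np

HIDDEN_SIZE = 768   # BERT hidden size, matching nn.Linear(768 + env_dim, ...) in the usage example
ENV_DIM = 7         # one value per ecological layer in the usage example
NUM_CLASSES = 10    # hypothetical; the real model has one output per training genus


def classify(last_hidden_state, env_features, weight, bias):
    """Mean-pool token embeddings, append environmental features, apply a linear layer.

    last_hidden_state: (batch, seq_len, HIDDEN_SIZE)
    env_features:      (batch, ENV_DIM), already standardised
    weight:            (NUM_CLASSES, HIDDEN_SIZE + ENV_DIM), laid out like torch.nn.Linear
    bias:              (NUM_CLASSES,)
    """
    dna_embeddings = last_hidden_state.mean(axis=1)                # (batch, HIDDEN_SIZE)
    combined = np.concatenate([dna_embeddings, env_features], axis=1)
    return combined @ weight.T + bias                              # (batch, NUM_CLASSES) logits


rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 5, HIDDEN_SIZE))   # 2 sequences of 5 tokens (dummy values)
env = rng.normal(size=(2, ENV_DIM))
W = rng.normal(size=(NUM_CLASSES, HIDDEN_SIZE + ENV_DIM))
b = np.zeros(NUM_CLASSES)
logits = classify(hidden, env, W, b)
print(logits.shape)  # (2, 10)
```

Because the head is linear, the DNA part and the environmental part contribute additively to each logit.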
## Example Usage
First you need to download the `StandardScaler` used to normalise environmental layer values during training. You can find it [here](https://huggingface.co/spaces/LofiAmazon/LofiAmazonSpace/blob/main/scaler.pkl).
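
With its default settings, scikit-learn's `StandardScaler.transform` rescales each feature independently as `(x - mean) / scale`, using the statistics stored in the fitted scaler's `mean_` and `scale_` attributes. The plain-Python sketch below reproduces that formula, assuming the pickled scaler was fitted with the defaults (`with_mean=True`, `with_std=True`); the means and scales shown are made-up numbers, not the values in `scaler.pkl`.

```py
def standardise(values, means, scales):
    """Apply (x - mean) / scale per feature, as StandardScaler.transform does by default."""
    if not (len(values) == len(means) == len(scales)):
        raise ValueError("values, means and scales must have the same length")
    return [(x - m) / s for x, m, s in zip(values, means, scales)]


# Hypothetical statistics for two layers; the real ones live in scaler.mean_ and scaler.scale_.
means = [120.0, 2200.0]
scales = [40.0, 500.0]
print(standardise([200.0, 1700.0], means, scales))  # [2.0, -1.0]
```

This is why the environmental values must be passed to the scaler in the same order as during training: each position is paired with its own mean and scale.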
You will also need to download all ecological layers from [here](https://huggingface.co/datasets/LofiAmazon/Global-Ecolayers/tree/main).
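
The usage example opens each layer by filename from the working directory, so a missing file only surfaces as an error partway through. A quick pre-flight check could look like the sketch below; `missing_layers` is a hypothetical helper, and the filenames are copied from the usage example (note the mix of `.tiff` and `.tif` extensions).

```py
from pathlib import Path

# Filenames used in the usage example.
ECOLAYERS = [
    "median_elevation_1km.tiff",
    "human_footprint.tiff",
    "population_density_1km.tif",
    "annual_precipitation.tif",
    "precipitation_seasonality.tif",
    "annual_mean_air_temp.tif",
    "temp_seasonality.tif",
]


def missing_layers(directory="."):
    """Return the ecological layer files that are not present in `directory`."""
    root = Path(directory)
    return [name for name in ECOLAYERS if not (root / name).is_file()]


missing = missing_layers(".")
if missing:
    print("Missing ecological layers:", ", ".join(missing))
```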
The model outputs logits over the genera that were present in the training dataset; the usage example converts them into a probability distribution with a temperature-scaled softmax. You can find the mapping between each index in the output vector and its genus name in [this file](https://huggingface.co/spaces/LofiAmazon/LofiAmazonSpace/blob/main/genus_labels.json).
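
Turning the model's output into ranked genus names, as described above, can be sketched in plain Python: divide the logits by a temperature (the usage example uses 0.2, which sharpens the distribution), apply a softmax, then look up the top indices. The layout of `genus_labels.json` is an assumption here: the sketch accepts either a JSON list ordered by output index or a JSON object mapping the index (as a string) to a name. The genus names in the demo are placeholders.

```py
import json
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over a 1-D list of logits divided by `temperature`.
    Temperatures below 1 sharpen the distribution; above 1 flatten it."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def normalise_labels(data):
    """ASSUMPTION: `data` is a list of names ordered by index, or a dict {"0": name, ...}."""
    if isinstance(data, dict):
        return {int(k): v for k, v in data.items()}
    return dict(enumerate(data))


def load_genus_labels(path):
    """Load genus_labels.json into an {index: genus} dict (file layout assumed, see above)."""
    with open(path) as f:
        return normalise_labels(json.load(f))


def top_k(probs, labels, k=3):
    """Return the k most probable (genus, probability) pairs."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(labels.get(i, f"<index {i}>"), probs[i]) for i in ranked[:k]]


labels = {0: "GenusA", 1: "GenusB", 2: "GenusC"}  # placeholder names, not real classes
probs = softmax_with_temperature([2.0, 1.0, 0.5], temperature=0.2)
for genus, p in top_k(probs, labels, k=2):
    print(f"{genus}: {p:.3f}")
```

With the PyTorch tensor from the usage example, `torch.topk(probs, k=5)` returns the top probabilities and their indices directly.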
```py
import pickle

import numpy as np
import rasterio
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from rasterio.sample import sample_gen
from transformers import BertConfig, BertForMaskedLM, PreTrainedTokenizerFast


class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, bert_model, env_dim, num_classes):
        super().__init__()
        self.bert = bert_model
        self.env_dim = env_dim
        self.num_classes = num_classes
        # Linear head over [768-dim DNA embedding ; env_dim environmental features]
        self.fc = nn.Linear(768 + env_dim, num_classes)

    def forward(self, bert_inputs, env_data):
        outputs = self.bert(**bert_inputs)
        # Mean-pool the last hidden layer over the token dimension
        dna_embeddings = outputs.hidden_states[-1].mean(1)
        combined = torch.cat((dna_embeddings, env_data), dim=1)
        logits = self.fc(combined)
        return logits


classification_model = DNASeqClassifier.from_pretrained(
    "LofiAmazon/BarcodeBERT-Finetuned-Amazon",
    bert_model=BertForMaskedLM(
        BertConfig(vocab_size=259, output_hidden_states=True),
    ),
)
classification_model.eval()  # disable dropout for inference

# Ecological layer files downloaded from LofiAmazon/Global-Ecolayers.
# The order must match the feature order the scaler was fitted with.
ecolayers = [
    "median_elevation_1km.tiff",
    "human_footprint.tiff",
    "population_density_1km.tif",
    "annual_precipitation.tif",
    "precipitation_seasonality.tif",
    "annual_mean_air_temp.tif",
    "temp_seasonality.tif",
]

with open("scaler.pkl", "rb") as f:
    scaler = pickle.load(f)

tokenizer = PreTrainedTokenizerFast.from_pretrained("LofiAmazon/BarcodeBERT-Entire-BOLD")

# The DNA sequence you want to classify.
# Insert a space after every 4 characters.
# Characters other than A, C, G and T (e.g. "-") are allowed.
# The sequence should be at most 660 characters long, not counting spaces.
dna_sequence = "AACA ATGT ATTT A-T- TTCG CCCT TGTG AATT TATT ..."

# Location where the DNA was sampled, as (latitude, longitude)
lat, lon = -3.009083, -58.68281

# Tokenize the DNA sequence
dna_tokenized = tokenizer(dna_sequence, return_tensors="pt")

# Sample each environmental layer at the sampling location.
# rasterio samples at (x, y) map coordinates; for rasters in a geographic
# CRS such as EPSG:4326 that means (longitude, latitude).
env_data = []
for layer in ecolayers:
    with rasterio.open(layer) as dataset:
        results = list(sample_gen(dataset, [(lon, lat)]))
        # Average over bands to get a single value per layer
        env_data.append(np.mean(results[0]))

# Normalise with the training-time scaler, then convert to a float32 tensor
env_data = scaler.transform([env_data])
env_data = torch.from_numpy(env_data).to(torch.float32)

# Obtain the genus logits
with torch.no_grad():
    logits = classification_model(dna_tokenized, env_data)

# A temperature-scaled softmax gives the final genus probabilities
temperature = 0.2
probs = torch.softmax(logits / temperature, dim=1).squeeze()
```
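
The comments in the example ask for a space after every four characters and at most 660 characters excluding spaces. A small helper along these lines can prepare a raw barcode string. `format_barcode` is a hypothetical name, and raising an error instead of truncating over-long sequences is a choice made in this sketch, not something the model card specifies.

```py
MAX_BASES = 660  # maximum length without spaces, per the comments in the usage example


def format_barcode(sequence, chunk=4, max_bases=MAX_BASES):
    """Remove existing whitespace and insert a space after every `chunk` characters.

    Characters other than A, C, G and T (e.g. '-') are kept as-is.
    Raises ValueError if the sequence is longer than `max_bases` characters.
    """
    bases = "".join(sequence.split())
    if len(bases) > max_bases:
        raise ValueError(f"sequence has {len(bases)} characters; the maximum is {max_bases}")
    return " ".join(bases[i:i + chunk] for i in range(0, len(bases), chunk))


print(format_barcode("AACAATGTATTTA-T-TTCG"))  # AACA ATGT ATTT A-T- TTCG
```

Stripping whitespace first makes the helper idempotent, so an already formatted sequence passes through unchanged.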