Model Overview: The model presented here builds on the BigBird architecture, following an approach similar to the one described in our paper "Leveraging Large Language Models for Metagenomic Analysis." It is optimized to improve BigBird's performance on large-scale gene sequence data. Because it is trained specifically on gene sequences, it aims to uncover useful insights in metagenomic data. It is evaluated on several tasks, including classification and sequence embedding.

Model Architecture:

  • Base Model: BigBird transformer architecture
  • Tokenizer: Custom K-mer Tokenizer with k-mer length of 6 and overlapping tokens
  • Training: Trained on a diverse dataset of 497 genes from 2000 bacterial and archaeal genomes
  • Embeddings: Generates sequence embeddings using mean pooling of hidden states
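The tokenizer line above specifies overlapping 6-mers. The following sketch shows what that means by splitting a sequence with a stride of 1 (overlapping) and with a stride of 6 (non-overlapping).

The helper `kmers` is hypothetical and written only for illustration. It is not KmerTokenizer's implementation. The real tokenizer also maps each k-mer to an integer ID and pads to `maxlen`, and this sketch omits both steps.

```python
def kmers(seq: str, k: int = 6, overlapping: bool = True) -> list[str]:
    """Split a DNA sequence into k-mers (illustrative helper, not KmerTokenizer).

    Overlapping mode slides the window by 1 base; non-overlapping mode slides
    it by k bases. In both modes an incomplete trailing k-mer is dropped.
    """
    stride = 1 if overlapping else k
    return [seq[i:i + k] for i in range(0, len(seq) - k + 1, stride)]


seq = "ATCGATGCAT"  # 10 bases
print(kmers(seq))                     # ['ATCGAT', 'TCGATG', 'CGATGC', 'GATGCA', 'ATGCAT']
print(kmers(seq, overlapping=False))  # ['ATCGAT']
```

With overlapping k-mers, a sequence of length L yields L - k + 1 tokens. Every base is therefore covered by up to k tokens, at the cost of a longer token sequence. This is one reason a sparse-attention model such as BigBird, which can handle long inputs, is a natural fit.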

Dataset: Scorpio Gene-Taxa Benchmark Dataset: https://zenodo.org/records/12964684

Steps to Use the Model:

  1. Install KmerTokenizer (the example also requires torch and transformers):

     pip install git+https://github.com/MsAlEhR/KmerTokenizer.git

  2. Example Code:

     from KmerTokenizer import KmerTokenizer
     from transformers import AutoModel
     import torch
     
     # Example gene sequence
     seq = "ATTTTTTTTTTTCCCCCCCCCCCGGGGGGGGATCGATGC"
     
     # Initialize the tokenizer
     tokenizer = KmerTokenizer(kmerlen=6, overlapping=True, maxlen=4096)
     tokenized_output = tokenizer.kmer_tokenize(seq)
     pad_token_id = 2  # Set pad token ID
     
     # Create attention mask (1 for tokens, 0 for padding)
     attention_mask = torch.tensor([1 if token != pad_token_id else 0 for token in tokenized_output], dtype=torch.long).unsqueeze(0)
     
     # Convert tokenized output to LongTensor and add batch dimension
     inputs = torch.tensor([tokenized_output], dtype=torch.long)
     
     # Load the pre-trained BigBird model and switch to inference mode (disables dropout)
     model = AutoModel.from_pretrained("MsAlEhR/MetaBERTa-bigbird-gene", output_hidden_states=True)
     model.eval()
     
     # Generate hidden states without tracking gradients
     with torch.no_grad():
         outputs = model(input_ids=inputs, attention_mask=attention_mask)
     
     # Get token embeddings from the last hidden layer: (batch, seq_len, hidden_size)
     embeddings = outputs.hidden_states[-1]
     
     # Expand the attention mask to (batch, seq_len, 1) so it broadcasts over hidden_size
     expanded_attention_mask = attention_mask.unsqueeze(-1).float()
     
     # Compute mean sequence embeddings
     mean_sequence_embeddings = torch.sum(expanded_attention_mask * embeddings, dim=1) / torch.sum(expanded_attention_mask, dim=1)
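The pooling step above averages only the positions where the attention mask is 1. This matters if the tokenizer pads the input to `maxlen`, because otherwise padding tokens would dilute the sequence embedding.

The following pure-Python sketch reproduces that arithmetic on a toy, hypothetical input of three tokens with two hidden dimensions. It involves no model or tokenizer and is meant only for checking the formula by hand. The function name `masked_mean` is illustrative.

```python
def masked_mean(hidden: list[list[float]], mask: list[int]) -> list[float]:
    """Mean of the per-token vectors in `hidden` whose mask value is 1."""
    total = sum(mask)
    if total == 0:
        raise ValueError("mask contains no non-padding positions")
    dim = len(hidden[0])
    # Numerator: sum of mask * hidden over the token axis (like torch.sum(..., dim=1))
    summed = [sum(m * vec[d] for vec, m in zip(hidden, mask)) for d in range(dim)]
    # Denominator: number of real (non-padding) tokens
    return [s / total for s in summed]


# Toy hidden states for 3 tokens; the last one is padding with large values
hidden = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]

print(masked_mean(hidden, mask))       # [2.0, 3.0]
print(masked_mean(hidden, [1, 1, 1]))  # padding included: values pulled toward 100
```

The first call ignores the padded position and returns the average of the two real tokens. The second call shows how an all-ones mask lets padding distort the result.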
    

Citation: For a detailed overview of leveraging large language models for metagenomic analysis, refer to our papers:

Refahi, M.S., Sokhansanj, B.A., & Rosen, G.L. (2023). Leveraging Large Language Models for Metagenomic Analysis. IEEE SPMB.

Refahi, M., Sokhansanj, B.A., Mell, J.C., Brown, J., Yoo, H., Hearne, G., & Rosen, G. (2024). Scorpio: Enhancing Embeddings to Improve Downstream Analysis of DNA Sequences. bioRxiv, 2024-07.

Model size: 35.8M parameters (Safetensors, tensor type F32)