---
license: apache-2.0
pipeline_tag: fill-mask
tags:
- feature-extraction
- transformers
widget:
- text: Paris is the <mask> of France.
  example_title: Capital
- text: The goal of life is <mask>.
  example_title: Philosophy
metrics:
- accuracy
---
|
|
|
# MGE-LLMs/SteelBERT
|
|
|
**SteelBERT** was pre-trained on the DeBERTa architecture using a corpus of 4.2 million **materials** abstracts and 55,000 full-text steel articles, amounting to roughly 0.96 billion words. For self-supervised training, SteelBERT masked 15% of the tokens using Masked Language Modeling (MLM), a universal and effective pretraining method for NLP tasks.
|
|
|
SteelBERT learns to reconstruct the masked tokens by adjusting parameters across its network layers. We allocated 95% of the corpus for training and 5% for validation; the validation loss reached 1.158 after 840 hours of training.
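
Because the card is tagged `fill-mask`, the MLM head can be queried directly through the standard `transformers` pipeline, assuming the released checkpoint includes that head. Below is a minimal sketch; the mask token is read from the loaded tokenizer rather than hard-coded, since the exact token depends on the released vocabulary.

```python
from transformers import pipeline

# Minimal fill-mask sketch; the mask token is taken from the tokenizer
# because it depends on the released vocabulary.
fill = pipeline("fill-mask", model="MGE-LLMs/SteelBERT")
mask = fill.tokenizer.mask_token

for pred in fill(f"Paris is the {mask} of France."):
    print(pred["token_str"], round(pred["score"], 3))
```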
|
|
|
### Why DeBERTa?
|
|
|
We chose the DeBERTa architecture because it introduces a disentangled attention mechanism that handles long-range dependencies, which is crucial for comprehending complex material interactions.
|
|
|
The original DeBERTa model's extensive sub-word vocabulary could introduce noise when tokenizing materials text. To address this, we trained a specialized tokenizer, constructing a vocabulary specific to the steel domain. Despite the smaller training corpus, we kept the vocabulary size consistent at 128,100 tokens to capture latent domain knowledge.
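
A quick way to see the effect of the domain tokenizer is to check its size and tokenize a steel-related sentence. The sketch below assumes only the model ID from this card; the exact token splits depend on the released vocabulary.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("MGE-LLMs/SteelBERT")

# Vocabulary size; expected to be on the order of the 128,100 entries described above.
print(len(tokenizer))

# Token splits for a steel-domain sentence (actual output depends on the released vocabulary).
print(tokenizer.tokenize("The 316L austenitic stainless steel was quenched and tempered."))
```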
|
|
|
### Model Architecture
|
|
|
SteelBERT comprises 188 million parameters and is built from 12 stacked Transformer encoder layers, each with 12 attention heads. We used the original DeBERTa codebase with a comparable configuration and model size.
|
|
|
The maximum input sequence length was set to 512 tokens, and training continued until the loss stopped decreasing. Pre-training used 8 NVIDIA A100 40 GB GPUs for 840 hours with a batch size of 576 sequences.
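
The stated configuration can be checked against the published checkpoint itself. A small sketch, assuming only the model ID from this card:

```python
from transformers import AutoConfig, AutoModel

# Read the published configuration rather than assuming it.
config = AutoConfig.from_pretrained("MGE-LLMs/SteelBERT")
print(config.num_hidden_layers, config.num_attention_heads, config.max_position_embeddings)

# Parameter count; this should be close to the 188 million figure quoted above.
model = AutoModel.from_pretrained("MGE-LLMs/SteelBERT")
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```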
|
|
|
### New Features
|
|
|
- **Specialized Tokenizer**: Trained on a steel materials corpus to improve tokenization accuracy, integrating insights from other material corpora.
- **Consistent Vocabulary Scale**: Maintained at 128,100 tokens to capture precise latent knowledge.
- **Efficient Training Configuration**: Utilized 8 NVIDIA A100 40 GB GPUs for 840 hours with a batch size of 576 sequences.
- **Enhanced Fine-tuning Capabilities**: Facilitates efficient fine-tuning for specific downstream tasks (see the sketch after this list), enhancing practical versatility.
- **Disentangled Attention Mechanism**: Effectively manages long-range dependencies, inherited from **DeBERTa**.
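
The fine-tuning sketch below is illustrative only: `steel_properties.csv`, its `text`/`label` columns, and the single-output regression head are assumptions for the example, not a released dataset or a prescribed setup.

```python
from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          Trainer, TrainingArguments)

# Hypothetical downstream task: predict a numeric property from a text description.
# "steel_properties.csv" with a "text" column and a floating-point "label" column
# is a placeholder, not a released dataset.
dataset = load_dataset("csv", data_files="steel_properties.csv")["train"].train_test_split(test_size=0.1)

tokenizer = AutoTokenizer.from_pretrained("MGE-LLMs/SteelBERT")
model = AutoModelForSequenceClassification.from_pretrained("MGE-LLMs/SteelBERT", num_labels=1)  # regression head

def tokenize(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=512)

dataset = dataset.map(tokenize, batched=True)

args = TrainingArguments(output_dir="steelbert-finetuned", num_train_epochs=3,
                         per_device_train_batch_size=16, learning_rate=2e-5)
trainer = Trainer(model=model, args=args,
                  train_dataset=dataset["train"], eval_dataset=dataset["test"])
trainer.train()
```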
|
|
|
### Usage Example
|
|
|
```python
from transformers import AutoTokenizer, AutoModel
import torch

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_path = "MGE-LLMs/SteelBERT"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path).to(device)  # Move model to GPU if available

# Example list of texts
texts = [
    "A composite steel plate for marine construction was fabricated using 316L stainless steel.",
    "The use of composite materials in construction is growing rapidly.",
    "Advances in material science are leading to stronger and more durable steel products."
]

# Tokenize the texts
inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).to(device)

# Print tokenized texts
for i, input_ids in enumerate(inputs['input_ids']):
    text_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    print(f"Tokens for text {i + 1}: {text_tokens}")

# Get [CLS] embeddings for each input text
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

hidden_states = outputs.hidden_states
last_hidden_state = hidden_states[-1]
cls_embeddings = last_hidden_state[:, 0, :]

# Print the [CLS] token embeddings for each text
print("CLS embeddings for each text:")
print(cls_embeddings)
```
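
The example above uses the final-layer `[CLS]` embedding as a sentence-level representation. Mean pooling over `last_hidden_state` (with padding tokens masked out) is a common alternative and may work better for some downstream similarity tasks.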