# minBERT / zemo3.py
# Author: GlowCheese — "First model version" (commit 9756d99)
import torch
from tokenizer import BertTokenizer
from torch import nn
from bert import BertModel
# Demo: tokenize a small batch of sentences with the project's BertTokenizer
# and print the shape of the embedding-layer output of the project's BertModel.

# Load the pretrained tokenizer for the 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Example sentences, encoded together as one batch.
sentences = [
    "She loves reading novels in her free time",
    "An apple a day keeps the doctor away",
    "If you can't explain it simply, you don't understand it well enough."
]

# Tokenize and encode all sentences at once: each row is padded to
# max_length=512 (and truncated if longer), returned as PyTorch tensors.
encoding = tokenizer.batch_encode_plus(
    sentences,
    max_length=512,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)

# Token IDs and the padding mask produced by the tokenizer.
input_ids = encoding['input_ids']
# NOTE: not used below — embed() is called with input_ids only.
attention_mask = encoding['attention_mask']

model = BertModel.from_pretrained('bert-base-uncased')
assert isinstance(model, BertModel)

# Only the output shape is inspected, so there is no need to record an
# autograd graph for this forward computation.
with torch.no_grad():
    print(model.embed(input_ids).size())