File size: 474 Bytes
a0b398e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from everything import *
from bert import BertModel

def get_finetuned_bert(mode: str):
    """Build a ``bert-base-uncased`` model and load fine-tuned weights into it.

    Args:
        mode: ``'sup'`` loads the checkpoint referenced by ``SUP_BERT``;
            ``'unsup'`` loads the one referenced by ``UNSUP_BERT``.

    Returns:
        The result of ``bert.to(device)``, where ``device`` is CUDA when
        ``USE_GPU`` is truthy and the CPU otherwise.

    Raises:
        AssertionError: If ``mode`` is neither ``'sup'`` nor ``'unsup'``.
    """
    if mode not in ('sup', 'unsup'):
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; otherwise an invalid mode would silently load the
        # unsupervised checkpoint. AssertionError is kept for compatibility.
        raise AssertionError(
            f"mode must be 'sup' or 'unsup', got {mode!r}"
        )

    bert = BertModel.from_pretrained('bert-base-uncased')

    checkpoint = SUP_BERT if mode == 'sup' else UNSUP_BERT
    # map_location='cpu': without it, tensors saved from a CUDA device are
    # restored onto CUDA, which fails on a CPU-only host even when USE_GPU is
    # false. The model is moved to the target device afterwards.
    state_dict = torch.load(checkpoint, map_location='cpu', weights_only=True)
    bert.load_state_dict(state_dict)

    device = torch.device('cuda') if USE_GPU else torch.device('cpu')
    return bert.to(device)