Spaces:
Runtime error
Runtime error
Commit
·
8b850ac
1
Parent(s):
74abd6a
[ERCBCM] Optimize the model interfaces and prints.
Browse files- ercbcm/ERCBCM.py +0 -1
- ercbcm/__init__.py +2 -3
- ercbcm/model_loader.py +4 -31
ercbcm/ERCBCM.py
CHANGED
@@ -6,7 +6,6 @@ class ERCBCM(nn.Module):
|
|
6 |
def __init__(self):
|
7 |
super(ERCBCM, self).__init__()
|
8 |
print('>>> ERCBCM Init!')
|
9 |
-
|
10 |
self.bert_base = BertForSequenceClassification.from_pretrained('bert-base-uncased')
|
11 |
|
12 |
def forward(self, text, label):
|
|
|
6 |
def __init__(self):
|
7 |
super(ERCBCM, self).__init__()
|
8 |
print('>>> ERCBCM Init!')
|
|
|
9 |
self.bert_base = BertForSequenceClassification.from_pretrained('bert-base-uncased')
|
10 |
|
11 |
def forward(self, text, label):
|
ercbcm/__init__.py
CHANGED
@@ -7,17 +7,16 @@ sys.path.insert(0, myPath + '/../')
|
|
7 |
|
8 |
import torch
|
9 |
|
10 |
-
from ercbcm.model_loader import
|
11 |
from ercbcm.ERCBCM import ERCBCM
|
12 |
from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
|
13 |
|
14 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
-
print('>>> GPU Available?', torch.cuda.is_available())
|
16 |
|
17 |
# ==========
|
18 |
|
19 |
model_for_predict = ERCBCM().to(device)
|
20 |
-
|
21 |
|
22 |
def predict(sentence, name):
|
23 |
label = torch.tensor([0])
|
|
|
7 |
|
8 |
import torch
|
9 |
|
10 |
+
from ercbcm.model_loader import load
|
11 |
from ercbcm.ERCBCM import ERCBCM
|
12 |
from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
|
13 |
|
14 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
15 |
|
16 |
# ==========
|
17 |
|
18 |
model_for_predict = ERCBCM().to(device)
|
19 |
+
load('ercbcm/model.pt', model_for_predict, device)
|
20 |
|
21 |
def predict(sentence, name):
|
22 |
label = torch.tensor([0])
|
ercbcm/model_loader.py
CHANGED
@@ -1,35 +1,8 @@
|
|
1 |
import torch
|
2 |
|
3 |
-
|
4 |
-
|
5 |
-
def save_checkpoint(save_path, model, valid_loss):
|
6 |
-
if save_path == None:
|
7 |
-
return
|
8 |
-
state_dict = {'model_state_dict': model.state_dict(),
|
9 |
-
'valid_loss': valid_loss}
|
10 |
-
torch.save(state_dict, save_path)
|
11 |
-
print('[SAVE] Model has been saved successfully to \'{}\''.format(save_path))
|
12 |
-
|
13 |
-
def load_checkpoint(load_path, model, device):
|
14 |
-
if load_path == None:
|
15 |
-
return
|
16 |
state_dict = torch.load(load_path, map_location=device)
|
17 |
-
print('[LOAD] Model has been loaded successfully from \'{}\''.format(load_path))
|
18 |
model.load_state_dict(state_dict['model_state_dict'])
|
19 |
-
|
20 |
-
|
21 |
-
def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
|
22 |
-
if save_path == None:
|
23 |
-
return
|
24 |
-
state_dict = {'train_loss_list': train_loss_list,
|
25 |
-
'valid_loss_list': valid_loss_list,
|
26 |
-
'global_steps_list': global_steps_list}
|
27 |
-
torch.save(state_dict, save_path)
|
28 |
-
print('[SAVE] Model with matrics has been saved successfully to \'{}\''.format(save_path))
|
29 |
-
|
30 |
-
def load_metrics(load_path, device):
|
31 |
-
if load_path == None:
|
32 |
-
return
|
33 |
-
state_dict = torch.load(load_path, map_location=device)
|
34 |
-
print('[LOAD] Model with matrics has been loaded successfully from \'{}\''.format(load_path))
|
35 |
-
return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']
|
|
|
1 |
import torch
|
2 |
|
3 |
+
def load(load_path, model, device):
|
4 |
+
if load_path == None: return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
state_dict = torch.load(load_path, map_location=device)
|
|
|
6 |
model.load_state_dict(state_dict['model_state_dict'])
|
7 |
+
print('[LOAD] Model has been loaded successfully from \'{}\''.format(load_path))
|
8 |
+
return state_dict['valid_loss']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|