lilingxi01 committed on
Commit
8b850ac
·
1 Parent(s): 74abd6a

[ERCBCM] Optimize the model interfaces and prints.

Browse files
Files changed (3) hide show
  1. ercbcm/ERCBCM.py +0 -1
  2. ercbcm/__init__.py +2 -3
  3. ercbcm/model_loader.py +4 -31
ercbcm/ERCBCM.py CHANGED
@@ -6,7 +6,6 @@ class ERCBCM(nn.Module):
6
  def __init__(self):
7
  super(ERCBCM, self).__init__()
8
  print('>>> ERCBCM Init!')
9
-
10
  self.bert_base = BertForSequenceClassification.from_pretrained('bert-base-uncased')
11
 
12
  def forward(self, text, label):
 
6
  def __init__(self):
7
  super(ERCBCM, self).__init__()
8
  print('>>> ERCBCM Init!')
 
9
  self.bert_base = BertForSequenceClassification.from_pretrained('bert-base-uncased')
10
 
11
  def forward(self, text, label):
ercbcm/__init__.py CHANGED
@@ -7,17 +7,16 @@ sys.path.insert(0, myPath + '/../')
7
 
8
  import torch
9
 
10
- from ercbcm.model_loader import load_checkpoint
11
  from ercbcm.ERCBCM import ERCBCM
12
  from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
13
 
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
- print('>>> GPU Available?', torch.cuda.is_available())
16
 
17
  # ==========
18
 
19
  model_for_predict = ERCBCM().to(device)
20
- load_checkpoint('ercbcm/model.pt', model_for_predict, device)
21
 
22
  def predict(sentence, name):
23
  label = torch.tensor([0])
 
7
 
8
  import torch
9
 
10
+ from ercbcm.model_loader import load
11
  from ercbcm.ERCBCM import ERCBCM
12
  from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
13
 
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
15
 
16
  # ==========
17
 
18
  model_for_predict = ERCBCM().to(device)
19
+ load('ercbcm/model.pt', model_for_predict, device)
20
 
21
  def predict(sentence, name):
22
  label = torch.tensor([0])
ercbcm/model_loader.py CHANGED
@@ -1,35 +1,8 @@
1
  import torch
2
 
3
- # Save and Load Functions.
4
-
5
- def save_checkpoint(save_path, model, valid_loss):
6
- if save_path == None:
7
- return
8
- state_dict = {'model_state_dict': model.state_dict(),
9
- 'valid_loss': valid_loss}
10
- torch.save(state_dict, save_path)
11
- print('[SAVE] Model has been saved successfully to \'{}\''.format(save_path))
12
-
13
- def load_checkpoint(load_path, model, device):
14
- if load_path == None:
15
- return
16
  state_dict = torch.load(load_path, map_location=device)
17
- print('[LOAD] Model has been loaded successfully from \'{}\''.format(load_path))
18
  model.load_state_dict(state_dict['model_state_dict'])
19
- return state_dict['valid_loss']
20
-
21
- def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
22
- if save_path == None:
23
- return
24
- state_dict = {'train_loss_list': train_loss_list,
25
- 'valid_loss_list': valid_loss_list,
26
- 'global_steps_list': global_steps_list}
27
- torch.save(state_dict, save_path)
28
- print('[SAVE] Model with matrics has been saved successfully to \'{}\''.format(save_path))
29
-
30
- def load_metrics(load_path, device):
31
- if load_path == None:
32
- return
33
- state_dict = torch.load(load_path, map_location=device)
34
- print('[LOAD] Model with matrics has been loaded successfully from \'{}\''.format(load_path))
35
- return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']
 
1
  import torch
2
 
3
+ def load(load_path, model, device):
4
+ if load_path == None: return
 
 
 
 
 
 
 
 
 
 
 
5
  state_dict = torch.load(load_path, map_location=device)
 
6
  model.load_state_dict(state_dict['model_state_dict'])
7
+ print('[LOAD] Model has been loaded successfully from \'{}\''.format(load_path))
8
+ return state_dict['valid_loss']