Ruslan-DS commited on
Commit
d64686c
·
1 Parent(s): 4718e6a

Update models/BertTunning.py

Browse files
Files changed (1) hide show
  1. models/BertTunning.py +7 -11
models/BertTunning.py CHANGED
@@ -1,17 +1,13 @@
1
- import pandas as pd
2
- import numpy as np
3
  import torch
4
  from torch import nn
5
- import torch.nn.functional as F
6
 
7
-
8
- from logreg_model import bert_for_logreg, tokenizer_bert
9
- from preprocess_bert import preprocess_bert
10
 
11
  MAX_LEN = 100
12
 
13
- class BertTunnig(nn.Module):
14
 
 
15
  def __init__(self, bert_model):
16
  super().__init__()
17
 
@@ -37,9 +33,9 @@ class BertTunnig(nn.Module):
37
  return torch.sigmoid(output)
38
 
39
 
40
- model_tunning = BertTunnig(bert_model=bert_for_logreg)
 
41
 
42
- model_tunning.load_state_dict(torch.load('best_weights_berttinnug(2).pt'))
43
 
44
  def predict_2(text):
45
 
@@ -48,6 +44,6 @@ def predict_2(text):
48
 
49
  with torch.inference_mode():
50
 
51
- predict = model_tunning(preprocessed_text, attention_mask=attention_mask).item()
52
 
53
- return round(predict)
 
 
 
1
  import torch
2
  from torch import nn
 
3
 
4
+ from models.preprocess_stage.bert_model import model
5
+ from models.preprocess_stage.bert_model import preprocess_bert
 
6
 
7
  MAX_LEN = 100
8
 
 
9
 
10
+ class BertTunnig(nn.Module):
11
  def __init__(self, bert_model):
12
  super().__init__()
13
 
 
33
  return torch.sigmoid(output)
34
 
35
 
36
+ model_tunning = BertTunnig(bert_model=model)
37
+ model_tunning.load_state_dict(torch.load('models/weights/BertTunnigWeights.pt'))
38
 
 
39
 
40
  def predict_2(text):
41
 
 
44
 
45
  with torch.inference_mode():
46
 
47
+ predict = round(model_tunning(preprocessed_text, attention_mask=attention_mask).item())
48
 
49
+ return predict