Ruslan-DS commited on
Commit
9e3fb53
·
1 Parent(s): 90f9180

Update models/BertTunning.py

Browse files
Files changed (1) hide show
  1. models/BertTunning.py +2 -2
models/BertTunning.py CHANGED
@@ -4,7 +4,7 @@ from models.preprocess_stage.bert_model import model
4
  from models.preprocess_stage.bert_model import preprocess_bert
5
 
6
  MAX_LEN = 100
7
- DEVICE='cpu'
8
 
9
  class BertTunnig(nn.Module):
10
  def __init__(self, bert_model):
@@ -41,7 +41,7 @@ def predict_2(text):
41
  preprocessed_text, attention_mask = preprocess_bert(text, MAX_LEN=MAX_LEN)
42
  preprocessed_text, attention_mask = torch.tensor(preprocessed_text).unsqueeze(0), torch.tensor([attention_mask])
43
 
44
- model_tunning.to(DEVICE)
45
  with torch.inference_mode():
46
 
47
  predict = round(model_tunning(preprocessed_text, attention_mask=attention_mask).item())
 
4
  from models.preprocess_stage.bert_model import preprocess_bert
5
 
6
  MAX_LEN = 100
7
+ # DEVICE='cpu'
8
 
9
  class BertTunnig(nn.Module):
10
  def __init__(self, bert_model):
 
41
  preprocessed_text, attention_mask = preprocess_bert(text, MAX_LEN=MAX_LEN)
42
  preprocessed_text, attention_mask = torch.tensor(preprocessed_text).unsqueeze(0), torch.tensor([attention_mask])
43
 
44
+ # model_tunning.to(DEVICE)
45
  with torch.inference_mode():
46
 
47
  predict = round(model_tunning(preprocessed_text, attention_mask=attention_mask).item())