Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,7 @@ from torch import tensor
|
|
8 |
|
9 |
import joblib
|
10 |
from dataclasses import dataclass
|
11 |
-
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, GPT2LMHeadModel, GPT2Tokenizer
|
12 |
import json
|
13 |
|
14 |
from preprocessing import predict_review, data_preprocessing_hard
|
@@ -191,6 +191,7 @@ elif selected_model == "Генерация текста GPT-моделью по
|
|
191 |
model = GPT2LMHeadModel.from_pretrained(model_name_or_path).to(DEVICE)
|
192 |
model = GPT2LMHeadModel(config)
|
193 |
model.load_state_dict(torch.load('model_dict.pt', map_location=device))
|
|
|
194 |
|
195 |
if st.button('Сделать гороскоп'):
|
196 |
start_time = time.time()
|
|
|
8 |
|
9 |
import joblib
|
10 |
from dataclasses import dataclass
|
11 |
+
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
|
12 |
import json
|
13 |
|
14 |
from preprocessing import predict_review, data_preprocessing_hard
|
|
|
191 |
model = GPT2LMHeadModel.from_pretrained(model_name_or_path).to(DEVICE)
|
192 |
model = GPT2LMHeadModel(config)
|
193 |
model.load_state_dict(torch.load('model_dict.pt', map_location=device))
|
194 |
+
model.eval()
|
195 |
|
196 |
if st.button('Сделать гороскоп'):
|
197 |
start_time = time.time()
|