nlp_project / task2.py
Tatiana
files added
dd3dbad
raw
history blame
3.13 kB
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from sklearn.preprocessing import LabelEncoder
labels = ['мода', 'спорт', 'технологии', 'финансы', 'крипта']
label_encoder = LabelEncoder()
label_encoder.fit(labels)
# Загрузка сохраненной модели и токенизатора в Streamlit
loaded_model_path = "rubert-base-cased"
loaded_tokenizer_path = BertForSequenceClassification.from_pretrained(loaded_model_path)
# Инициализация модели и токенизатора
loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)
# Создание модели с архитектурой BertForSequenceClassification
# Передайте в аргумент `num_labels` количество классов, для которых модель будет выполнять классификацию
model = BertForSequenceClassification(num_labels=len(labels))
# Загрузка весов из сохраненного файла
weights_path = "model_weights_epoch_8.pt"
state_dict = torch.load(weights_path, map_location='cpu') # Укажите 'cuda' вместо 'cpu', если используете GPU
model.load_state_dict(state_dict)
# Пример использования загруженной модели
user_input = "Ваш текст для классификации"
predicted_class = predict_class(user_input, model=model, tokenizer=loaded_tokenizer, label_encoder=label_encoder)
print(predicted_class)
# #Загрузка сохраненной модели и токенизатора в Streamlit
# loaded_model_path = "nlp_project/model"
# loaded_tokenizer_path = "nlp_project/tokenizer"
# loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
# loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)
def predict_class(user_input, model=loaded_model, tokenizer=loaded_tokenizer, label_encoder=label_encoder, max_length=128):
if not user_input:
return "Введите текст"
def tokenize_text(text):
encoded_text = tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=max_length,
pad_to_max_length=True,
return_attention_mask=True,
return_tensors='pt'
)
return encoded_text
encoded_text = tokenize_text(user_input)
with torch.no_grad():
model.eval()
input_ids = encoded_text['input_ids']
attention_mask = encoded_text['attention_mask']
outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits
predicted_class_index = torch.argmax(logits, dim=1).item()
# Получение названия класса
predicted_class = label_encoder.classes_[predicted_class_index]
return predicted_class