|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
|
|
|
|
model_name = 'Yerzhxn/class_space' |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
model = AutoModelForSequenceClassification.from_pretrained(model_name) |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
model.to(device) |
|
|
|
|
|
st.title("Тестирование классификации текста") |
|
st.write("Введите текст, чтобы узнать предсказанный класс.") |
|
|
|
|
|
input_text = st.text_area("Введите текст здесь", "") |
|
|
|
if st.button("Предсказать"): |
|
if input_text: |
|
|
|
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True) |
|
inputs = {key: value.to(device) for key, value in inputs.items()} |
|
|
|
|
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
|
|
|
|
logits = outputs.logits |
|
predicted_class = torch.argmax(logits, dim=1).item() |
|
|
|
|
|
st.write(f"Предсказанный класс: {predicted_class}") |
|
else: |
|
st.write("Пожалуйста, введите текст для классификации.") |
|
|