import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import os
import re
import emoji

# Hugging Face access token used to download the ruSpamNS checkpoints
TOKEN = os.getenv("HF_TOKEN")
# Model name -> Hugging Face Hub repo id
MODEL_PATHS = {
    "ruSpamNS_v13": "NeuroSpaceX/ruSpamNS_v13",
    "ruSpamNS_big": "NeuroSpaceX/ruSpamNS_big",
    "ruSpamNS_small": "NeuroSpaceX/ruSpamNS_small",
    "ruSpamNS_v14": "NeuroSpaceX/ruSpamNS_v14",
    "ruSpamNS_v14_multiclass": "NeuroSpaceX/ruSpamNS_v14_multiclass",
    "ruSpamNS_v16_multiclass": "NeuroSpaceX/ruSpamNS_v16_multiclass",
    "ruSpamNS_v17_multiclass": "NeuroSpaceX/ruSpamNS_v17_multiclass",
    "ruSpamNS_v19_multiclass": "NeuroSpaceX/ruSpamNS_v19_multiclass",
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load one tokenizer and one model per checkpoint up front so inference requests are fast.
tokenizers = {name: AutoTokenizer.from_pretrained(path, token=TOKEN) for name, path in MODEL_PATHS.items()}
models = {name: AutoModelForSequenceClassification.from_pretrained(path, token=TOKEN).to(device) for name, path in MODEL_PATHS.items()}
def clean_text(text):
    """Normalize a message before classification: drop emoji, keep only letters, fix casing and whitespace."""
    text = emoji.replace_emoji(text, replace='')
    text = re.sub(r'[^a-zA-Zа-яА-ЯёЁ ]', '', text, flags=re.UNICODE)
    text = text.capitalize()  # lowercases the whole string and uppercases the first character
    text = re.sub(r'\s+', ' ', text).strip()
    return text
def classify_text(text, model_choice):
    tokenizer = tokenizers[model_choice]
    model = models[model_choice]
    message = clean_text(text)
    encoding = tokenizer(message, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask).logits
    if "multiclass" in model_choice:
        # Multiclass checkpoints: softmax over the class logits.
        probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0]
        if model_choice == "ruSpamNS_v19_multiclass":
            # Labels: NOT SPAM, SPAM, REAL ESTATE/GOODS, JOB ADS, ADS/SERVICES
            labels = ["НЕ СПАМ", "СПАМ", "НЕДВИЖИМОСТЬ/ТОВАРЫ", "ВАКАНСИИ", "РЕКЛАМА/УСЛУГИ"]
        else:
            # Labels: NOT SPAM, SPAM, REAL ESTATE, JOB ADS, SALES
            labels = ["НЕ СПАМ", "СПАМ", "НЕДВИЖИМОСТЬ", "ВАКАНСИИ", "ПРОДАЖА"]
        predicted_index = probabilities.argmax()
        predicted_label = labels[predicted_index]
        confidence = probabilities[predicted_index] * 100
        return f"{predicted_label} (вероятность: {confidence:.2f}%)"
    else:
        # Binary checkpoints: a single logit, sigmoid gives the spam probability.
        prediction = torch.sigmoid(outputs).cpu().numpy()[0][0]
        label = "СПАМ" if prediction >= 0.5 else "НЕ СПАМ"
        # Report the confidence of the chosen label, not the raw spam score.
        confidence = prediction if prediction >= 0.5 else 1 - prediction
        return f"{label} (вероятность: {confidence*100:.2f}%)"
iface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(lines=3, placeholder="Введите текст..."),
        gr.Radio(list(models.keys()), label="Выберите модель")
    ],
    outputs="text",
    title="ruSpamNS - Проверка на спам",
    description="Введите текст, чтобы проверить, является ли он спамом."
)

iface.launch()