Spaces:
Sleeping
Sleeping
File size: 2,274 Bytes
4950ce7 f04015a ff2438f 4950ce7 f04015a 4950ce7 67d2259 f86ae42 67d2259 4950ce7 f04015a ff2438f b7a5d14 ff2438f f04015a a9b4211 67d2259 a9b4211 f04015a 67d2259 4950ce7 f04015a 67d2259 f04015a 4950ce7 a9b4211 f86ae42 a9b4211 4950ce7 a9b4211 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import os
import re
import emoji
# Hugging Face access token read from the environment (None if unset).
TOKEN = os.getenv("HF_TOKEN")

# Choose the device once; both the models and the input tensors in
# classify_text use this same value, so they always live on the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# UI name -> Hugging Face Hub repository id. Kept under its own name so the
# mapping is still available after the models themselves are loaded.
MODEL_PATHS = {
    "ruSpamNS_v13": "NeuroSpaceX/ruSpamNS_v13",
    "ruSpamNS_big": "NeuroSpaceX/ruSpamNS_big",
    "ruSpamNS_small": "NeuroSpaceX/ruSpamNS_small",
    "ruSpamNS_v14": "NeuroSpaceX/ruSpamNS_v14",
}

# Everything is loaded once at import time so each request only runs inference.
# NOTE(review): recent transformers releases deprecate `use_auth_token` in favour
# of `token=`; switch once the Space's pinned transformers version supports it.
tokenizers = {
    name: AutoTokenizer.from_pretrained(path, use_auth_token=TOKEN)
    for name, path in MODEL_PATHS.items()
}
models = {
    # .eval() makes inference mode explicit (dropout disabled).
    name: AutoModelForSequenceClassification.from_pretrained(path, use_auth_token=TOKEN)
    .to(device)
    .eval()
    for name, path in MODEL_PATHS.items()
}
def clean_text(text):
    """Normalise a raw message before it is tokenised.

    In order, the function:
      1. removes emoji;
      2. drops every character that is not a Latin/Cyrillic letter or a plain space;
      3. lower-cases the text, then capitalises its first character;
      4. collapses runs of whitespace to one space and trims both ends.

    NOTE(review): tabs and newlines are not in the allowed character set, so
    step 2 deletes them outright (words on adjacent lines get glued together).
    Because ``capitalize`` runs before the final strip, input that begins with
    a space ends up entirely lower-case. Presumably this mirrors the
    preprocessing used at training time; confirm before changing it.
    """
    letters_and_spaces = re.sub(
        r'[^a-zA-Zа-яА-ЯёЁ ]',
        '',
        emoji.replace_emoji(text, replace=''),
        flags=re.UNICODE,
    )
    cased = letters_and_spaces.lower().capitalize()
    return re.sub(r'\s+', ' ', cased).strip()
def classify_text(text, model_choice):
    """Classify *text* as spam / not spam with the selected model.

    Args:
        text: Raw user message. It is normalised with ``clean_text`` before
            tokenisation; ``None`` is treated as an empty string.
        model_choice: Key into the module-level ``models`` and ``tokenizers``
            dicts (the value selected in the Radio input).

    Returns:
        The label "СПАМ" (spam) or "НЕ СПАМ" (not spam), followed by the
        probability as a percentage, e.g. ``"СПАМ (вероятность: 97.31%)"``.

    Raises:
        gr.Error: if no model or an unknown one was selected, so the user sees
            a readable message instead of a bare KeyError.
    """
    if model_choice not in models:
        # e.g. None when the Radio is left unselected (Gradio behaviour, confirm).
        raise gr.Error("Выберите модель из списка.")
    tokenizer = tokenizers[model_choice]
    model = models[model_choice]
    message = clean_text(text or "")
    encoding = tokenizer(message, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    # Inputs must be on the same device the models were moved to at load time.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
    # Assumes a single-logit binary head, so sigmoid(logits[0][0]) is P(spam).
    # TODO confirm against the model config.
    prediction = torch.sigmoid(logits).cpu().numpy()[0][0]
    label = "СПАМ" if prediction >= 0.5 else "НЕ СПАМ"
    return f"{label} (вероятность: {prediction*100:.2f}%)"
# Radio choices come straight from the loaded models, so the UI list and the
# lookup in classify_text cannot drift apart (dict order = load order).
MODEL_CHOICES = list(models)

iface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(lines=3, placeholder="Введите текст..."),
        # Preselect a model so classify_text never receives an empty choice.
        gr.Radio(MODEL_CHOICES, value=MODEL_CHOICES[0], label="Выберите модель"),
    ],
    outputs="text",
    title="ruSpamNS - Проверка на спам",
    description="Введите текст, чтобы проверить, является ли он спамом.",
)
iface.launch()
|