File size: 3,143 Bytes
4950ce7
 
 
 
f04015a
ff2438f
4950ce7
f04015a
4950ce7
67d2259
 
 
f86ae42
414c376
c272398
65e0ec6
6568a98
caf3792
67d2259
 
 
 
4950ce7
f04015a
 
 
ff2438f
b7a5d14
ff2438f
 
 
 
f04015a
a9b4211
67d2259
 
a9b4211
f04015a
 
 
 
67d2259
4950ce7
f04015a
6568a98
414c376
caf3792
c391082
 
 
414c376
 
 
 
 
 
 
 
4950ce7
 
 
a9b4211
 
414c376
a9b4211
4950ce7
 
 
 
 
a9b4211
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import os
import re
import emoji

# Hugging Face access token read from the environment; forwarded to every
# from_pretrained call below (may be None when HF_TOKEN is unset).
TOKEN = os.getenv("HF_TOKEN")

# Display name (shown in the UI) -> Hugging Face Hub repository id.
MODEL_REPOS = {
    "ruSpamNS_v13": "NeuroSpaceX/ruSpamNS_v13",
    "ruSpamNS_big": "NeuroSpaceX/ruSpamNS_big",
    "ruSpamNS_small": "NeuroSpaceX/ruSpamNS_small",
    "ruSpamNS_v14": "NeuroSpaceX/ruSpamNS_v14",
    "ruSpamNS_v14_multiclass": "NeuroSpaceX/ruSpamNS_v14_multiclass",
    "ruSpamNS_v16_multiclass": "NeuroSpaceX/ruSpamNS_v16_multiclass",
    "ruSpamNS_v17_multiclass": "NeuroSpaceX/ruSpamNS_v17_multiclass",
    "ruSpamNS_v19_multiclass": "NeuroSpaceX/ruSpamNS_v19_multiclass",
}

# Choose the compute device once; it is used both to place the models here
# and to move input tensors in classify_text.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): newer transformers releases deprecate `use_auth_token` in
# favor of `token`; switch once the pinned transformers version supports it.
tokenizers = {
    name: AutoTokenizer.from_pretrained(path, use_auth_token=TOKEN)
    for name, path in MODEL_REPOS.items()
}
# Loaded models keyed by the same display names as `tokenizers`.
models = {
    name: AutoModelForSequenceClassification.from_pretrained(path, use_auth_token=TOKEN).to(device)
    for name, path in MODEL_REPOS.items()
}

def clean_text(text):
    """Normalize a raw message before tokenization.

    Removes emoji, keeps only Latin/Cyrillic letters (including ё/Ё) and
    single spaces, trims the ends, and returns the text lowercased with its
    first character capitalized.

    Args:
        text: Raw user input (may span several lines).

    Returns:
        The normalized string; ``""`` if nothing survives the filtering.
    """
    text = emoji.replace_emoji(text, replace='')
    # Turn every whitespace run (newlines, tabs, NBSP, ...) into one plain
    # space *before* filtering; otherwise the filter below deletes those
    # characters and glues neighbouring words together.
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^a-zA-Zа-яА-ЯёЁ ]', '', text)
    # Removing characters can leave adjacent spaces; collapse and trim.
    text = re.sub(r' +', ' ', text).strip()
    # Capitalize only after trimming so a leading space/emoji/digit does not
    # absorb the capitalization. str.capitalize also lowercases the rest.
    return text.capitalize()

def classify_text(text, model_choice):
    """Classify ``text`` with the model selected in the UI.

    Args:
        text: Raw user input; normalized with ``clean_text`` before tokenizing.
        model_choice: Key into ``models`` / ``tokenizers``. Names containing
            "multiclass" use a softmax head; the rest use a single sigmoid logit.

    Returns:
        A string "<LABEL> (вероятность: NN.NN%)", where the percentage is the
        model's confidence in the reported label.

    Raises:
        KeyError: if ``model_choice`` is not a loaded model name.
    """
    tokenizer = tokenizers[model_choice]
    model = models[model_choice]

    message = clean_text(text)
    encoding = tokenizer(message, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits

    if "multiclass" in model_choice:
        probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]
        # v19 uses a different label set for classes 2 and 4.
        if model_choice == "ruSpamNS_v19_multiclass":
            labels = ["НЕ СПАМ", "СПАМ", "НЕДВИЖИМОСТЬ/ТОВАРЫ", "ВАКАНСИИ", "РЕКЛАМА/УСЛУГИ"]
        else:
            labels = ["НЕ СПАМ", "СПАМ", "НЕДВИЖИМОСТЬ", "ВАКАНСИИ", "ПРОДАЖА"]
        predicted_index = probabilities.argmax()
        label = labels[predicted_index]
        confidence = probabilities[predicted_index] * 100
    else:
        spam_probability = torch.sigmoid(logits).cpu().numpy()[0][0]
        is_spam = spam_probability >= 0.5
        label = "СПАМ" if is_spam else "НЕ СПАМ"
        # Report confidence in the predicted label (as the multiclass branch
        # does), not P(spam) regardless of which label was chosen.
        confidence = (spam_probability if is_spam else 1 - spam_probability) * 100

    return f"{label} (вероятность: {confidence:.2f}%)"

# Radio options in registry order; the first one is preselected because with
# no selection Gradio passes None and classify_text would raise KeyError.
_MODEL_CHOICES = list(models.keys())

iface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(lines=3, placeholder="Введите текст..."),
        gr.Radio(_MODEL_CHOICES, value=_MODEL_CHOICES[0], label="Выберите модель"),
    ],
    outputs="text",
    title="ruSpamNS - Проверка на спам",
    description="Введите текст, чтобы проверить, является ли он спамом."
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()