File size: 2,104 Bytes
4950ce7
 
 
 
f04015a
ff2438f
4950ce7
f04015a
4950ce7
67d2259
 
 
 
 
 
 
 
4950ce7
f04015a
 
 
ff2438f
b7a5d14
ff2438f
 
 
 
f04015a
a9b4211
67d2259
 
a9b4211
f04015a
 
 
 
67d2259
4950ce7
f04015a
 
67d2259
f04015a
 
4950ce7
 
 
a9b4211
 
67d2259
a9b4211
4950ce7
 
 
 
 
a9b4211
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import os
import re
import emoji

# Hugging Face Hub access token taken from the HF_TOKEN environment variable;
# os.getenv returns None when the variable is not set.
TOKEN = os.getenv("HF_TOKEN")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Display name -> Hub repository id. The same name is rebound further down to
# the loaded model objects, keyed by the same display names.
models = {
    "ruSpamNS_v13": "NeuroSpaceX/ruSpamNS_v13",
    "ruSpamNS_big": "NeuroSpaceX/ruSpamNS_big",
    "ruSpamNS_small": "NeuroSpaceX/ruSpamNS_small",
}

# All tokenizers are loaded first, then all models (both at import time).
# NOTE(review): `use_auth_token` is reportedly deprecated in newer transformers
# releases in favour of `token` - confirm against the installed version.
tokenizers = {
    model_name: AutoTokenizer.from_pretrained(repo_id, use_auth_token=TOKEN)
    for model_name, repo_id in models.items()
}
models = {
    model_name: AutoModelForSequenceClassification.from_pretrained(
        repo_id, use_auth_token=TOKEN
    ).to(device)
    for model_name, repo_id in models.items()
}

def clean_text(text):
    """Normalize raw user text before it is tokenized.

    Steps, in order:
      1. Remove emoji via ``emoji.replace_emoji``.
      2. Drop every character that is not a Latin letter, a Cyrillic letter
         (including ё/Ё) or a plain space; digits, punctuation, tabs and
         newlines are removed outright, not replaced by a space.
      3. Lowercase the text, then ``str.capitalize`` it. Capitalization acts
         on the very first character, so a string that still starts with a
         space at this point stays fully lowercase.
      4. Collapse whitespace runs into one space and strip both ends.

    Args:
        text: The raw input string.

    Returns:
        The normalized string (possibly empty).
    """
    without_emoji = emoji.replace_emoji(text, replace='')
    letters_only = re.sub(r'[^a-zA-Zа-яА-ЯёЁ ]', '', without_emoji, flags=re.UNICODE)
    sentence_cased = letters_only.lower().capitalize()
    return re.sub(r'\s+', ' ', sentence_cased).strip()

def classify_text(text, model_choice):
    """Classify *text* as spam / not spam with the selected model.

    The text is normalized with ``clean_text``, tokenized to a fixed length of
    128 tokens (padded and truncated), and passed through the chosen model.
    The sigmoid of the first logit of the first (only) row is reported as the
    spam probability.

    Args:
        text: Raw text entered by the user.
        model_choice: Key into the module-level ``tokenizers`` / ``models``
            dictionaries.

    Returns:
        A string of the form ``"<label> (вероятность: <pp.pp>%)"`` where the
        label is "СПАМ" when the probability is >= 0.5 and "НЕ СПАМ" otherwise.

    Raises:
        gr.Error: If ``model_choice`` is not a loaded model name, for example
            ``None`` when no model was selected. This replaces the bare
            ``KeyError`` with a message the UI can display.
    """
    # Validate up front: a missing or unknown selection would otherwise
    # surface as an opaque KeyError from the dictionary lookups below.
    if model_choice not in models:
        raise gr.Error("Выберите модель из списка.")

    tokenizer = tokenizers[model_choice]
    model = models[model_choice]

    message = clean_text(text)
    encoding = tokenizer(message, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    # Inputs must live on the same device the models were moved to.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask).logits
        # assumes the model emits a single spam logit per input - TODO confirm
        prediction = torch.sigmoid(outputs).cpu().numpy()[0][0]

    label = "СПАМ" if prediction >= 0.5 else "НЕ СПАМ"
    return f"{label} (вероятность: {prediction*100:.2f}%)"

# Offer exactly the models that were loaded, so the radio choices cannot drift
# from the keys that classify_text looks up.
_MODEL_CHOICES = list(models)

# Web UI: a text box plus a model selector in, the classification string out.
iface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(lines=3, placeholder="Введите текст..."),
        # Preselect the first model: without a default value the radio submits
        # None until the user clicks an option.
        gr.Radio(_MODEL_CHOICES, value=_MODEL_CHOICES[0], label="Выберите модель"),
    ],
    outputs="text",
    title="ruSpamNS - Проверка на спам",
    description="Введите текст, чтобы проверить, является ли он спамом.",
)

# Start the Gradio server when the script is executed.
iface.launch()