import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import os
import re
import emoji

TOKEN = os.getenv("HF_TOKEN")  # Hub access token (needed if the model repos are gated/private)

# Model name -> Hugging Face Hub repository
MODEL_PATHS = {
    "ruSpamNS_v13": "NeuroSpaceX/ruSpamNS_v13",
    "ruSpamNS_big": "NeuroSpaceX/ruSpamNS_big",
    "ruSpamNS_small": "NeuroSpaceX/ruSpamNS_small",
    "ruSpamNS_v14": "NeuroSpaceX/ruSpamNS_v14"  # newly added model
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load all tokenizers and models once at startup
tokenizers = {name: AutoTokenizer.from_pretrained(path, use_auth_token=TOKEN) for name, path in MODEL_PATHS.items()}
models = {name: AutoModelForSequenceClassification.from_pretrained(path, use_auth_token=TOKEN).to(device) for name, path in MODEL_PATHS.items()}

def clean_text(text):
    # Drop emoji, keep only Latin/Cyrillic letters and spaces,
    # then sentence-case and collapse whitespace
    text = emoji.replace_emoji(text, replace='')
    text = re.sub(r'[^a-zA-Zа-яА-ЯёЁ ]', '', text, flags=re.UNICODE)
    text = text.capitalize()  # also lowercases the rest, so a separate lower() call is not needed
    text = re.sub(r'\s+', ' ', text).strip()
    return text
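
# Rough illustration of clean_text (assumed behaviour of the regex + capitalize chain above):
#   clean_text("🔥СРОЧНО!!! Заработок 100% в день, пиши в ЛС🔥")
#   -> "Срочно заработок в день пиши в лс"
# Emoji, digits and punctuation are stripped; the result is sentence-cased with collapsed whitespace.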

def classify_text(text, model_choice):
    """Classify one message with the selected model and return a labelled spam probability."""
    tokenizer = tokenizers[model_choice]
    model = models[model_choice]
    
    message = clean_text(text)
    encoding = tokenizer(message, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask).logits
        # Single-logit head: sigmoid maps the raw logit to a spam probability
        prediction = torch.sigmoid(outputs).cpu().numpy()[0][0]
    
    label = "СПАМ" if prediction >= 0.5 else "НЕ СПАМ"
    return f"{label} (вероятность: {prediction*100:.2f}%)"

iface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(lines=3, placeholder="Введите текст..."),
        gr.Radio(["ruSpamNS_v13", "ruSpamNS_big", "ruSpamNS_small", "ruSpamNS_v14"], label="Выберите модель")  # new model added to the selection
    ],
    outputs="text",
    title="ruSpamNS - Проверка на спам",
    description="Введите текст, чтобы проверить, является ли он спамом."
)

iface.launch()
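
# Usage sketch (comments only, since iface.launch() above blocks the process):
# one way to query the running app programmatically, assuming the `gradio_client`
# package is installed and the app is served at the default local URL.
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict("Выиграй приз прямо сейчас!", "ruSpamNS_v13", api_name="/predict")
#   print(result)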