Spaces:
Running
Running
File size: 2,104 Bytes
4950ce7 f04015a ff2438f 4950ce7 f04015a 4950ce7 67d2259 4950ce7 f04015a ff2438f b7a5d14 ff2438f f04015a a9b4211 67d2259 a9b4211 f04015a 67d2259 4950ce7 f04015a 67d2259 f04015a 4950ce7 a9b4211 67d2259 a9b4211 4950ce7 a9b4211 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import os
import re
import emoji
# Access token for the model repositories, read from the environment
# (None when HF_TOKEN is not set).
TOKEN = os.getenv("HF_TOKEN")

# Resolve the target device once, before anything is loaded, so that model
# placement and the inference inputs always agree on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Display name -> identifier handed to `from_pretrained` for each model.
MODEL_PATHS = {
    "ruSpamNS_v13": "NeuroSpaceX/ruSpamNS_v13",
    "ruSpamNS_big": "NeuroSpaceX/ruSpamNS_big",
    "ruSpamNS_small": "NeuroSpaceX/ruSpamNS_small",
}

# NOTE(review): newer transformers releases reportedly deprecate
# `use_auth_token` in favour of `token`; kept unchanged until the installed
# version is confirmed.

# Display name -> tokenizer, loaded eagerly at start-up.
tokenizers = {
    name: AutoTokenizer.from_pretrained(path, use_auth_token=TOKEN)
    for name, path in MODEL_PATHS.items()
}

# Display name -> classification model, loaded eagerly and moved to `device`.
models = {
    name: AutoModelForSequenceClassification.from_pretrained(
        path, use_auth_token=TOKEN
    ).to(device)
    for name, path in MODEL_PATHS.items()
}
# Matches every character that is neither a Latin/Cyrillic letter nor
# whitespace. Whitespace is kept here on purpose: deleting newlines or tabs
# would glue the neighbouring words together.
_DISALLOWED_CHARS = re.compile(r'[^a-zA-Zа-яА-ЯёЁ\s]')

# Matches any run of whitespace (spaces, tabs, line breaks).
_WHITESPACE_RUN = re.compile(r'\s+')


def clean_text(text):
    """Normalise raw user text before it is tokenised.

    Steps, in order:
      1. remove emoji;
      2. drop every character that is not a Latin/Cyrillic letter or
         whitespace (digits and punctuation are removed);
      3. collapse each whitespace run to a single space and trim the ends;
      4. capitalise: first letter upper-case, the rest lower-case.

    Whitespace is normalised *before* capitalising so that text which started
    with a removed character (a digit, an emoji, punctuation) is still
    capitalised consistently.

    Args:
        text: The raw message as a string.

    Returns:
        The cleaned string; empty when no letters remain.
    """
    text = emoji.replace_emoji(text, replace='')
    text = _DISALLOWED_CHARS.sub('', text)
    text = _WHITESPACE_RUN.sub(' ', text).strip()
    # str.capitalize() already lower-cases everything after the first
    # character, so no separate lower() pass is needed.
    return text.capitalize()
def classify_text(text, model_choice):
    """Classify a message as spam or not spam with the selected model.

    The text is cleaned with `clean_text`, tokenised to a fixed length of 128
    tokens (padded/truncated), and run through the chosen model without
    gradient tracking. The sigmoid of the first logit of the first (only)
    batch item is used as the spam score; a score of 0.5 or more is labelled
    as spam.

    Args:
        text: The raw message entered by the user. A missing value (None) is
            treated as empty text.
        model_choice: Key into the module-level `tokenizers` / `models`
            dictionaries.

    Returns:
        A string with the label and the score formatted as a percentage.

    Raises:
        gr.Error: If `model_choice` is not one of the loaded models (for
            example when nothing was selected), so the UI shows a readable
            message instead of a bare KeyError.
    """
    if model_choice not in models:
        raise gr.Error("Пожалуйста, выберите модель.")

    tokenizer = tokenizers[model_choice]
    model = models[model_choice]

    message = clean_text(text or "")
    encoding = tokenizer(
        message,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    # Inputs must live on the same device the models were moved to.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits

    # NOTE(review): assumes the model emits a single spam logit per input
    # (index [0, 0]) — confirm against the model configuration.
    spam_probability = torch.sigmoid(logits)[0, 0].item()

    label = "СПАМ" if spam_probability >= 0.5 else "НЕ СПАМ"
    return f"{label} (вероятность: {spam_probability * 100:.2f}%)"
# Offer exactly the models that were loaded, in their declaration order,
# instead of repeating the names by hand.
model_names = list(models)

# Web UI: a multi-line text box plus a model selector, wired to
# `classify_text`, with the result shown as plain text.
iface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(lines=3, placeholder="Введите текст..."),
        # Preselect the first model so that a submission can never arrive
        # without a model choice.
        gr.Radio(model_names, value=model_names[0], label="Выберите модель"),
    ],
    outputs="text",
    title="ruSpamNS - Проверка на спам",
    description="Введите текст, чтобы проверить, является ли он спамом.",
)

# Start the Gradio server.
iface.launch()
|