Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -2,19 +2,34 @@ import gradio as gr
|
|
2 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
3 |
import torch
|
4 |
import os
|
|
|
5 |
|
6 |
MODEL_NAME = "NeuroSpaceX/ruSpamNS"
|
7 |
-
TOKEN = os.getenv("HF_TOKEN")
|
8 |
|
9 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=TOKEN)
|
10 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, use_auth_token=TOKEN)
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
def classify_text(text):
|
13 |
-
|
|
|
|
|
|
|
14 |
with torch.no_grad():
|
15 |
-
outputs = model(
|
16 |
-
|
17 |
-
|
|
|
18 |
|
19 |
iface = gr.Interface(
|
20 |
fn=classify_text,
|
@@ -24,4 +39,4 @@ iface = gr.Interface(
|
|
24 |
description="Введите текст, чтобы проверить, является ли он спамом."
|
25 |
)
|
26 |
|
27 |
-
iface.launch()
|
|
|
2 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
3 |
import torch
|
4 |
import os
|
5 |
+
import re
|
6 |
|
7 |
# Hugging Face Hub repository that provides both the tokenizer and the classifier.
MODEL_NAME = "NeuroSpaceX/ruSpamNS"
# Hub access token taken from the environment; None when HF_TOKEN is not set.
TOKEN = os.getenv("HF_TOKEN")

# `token=` is the supported keyword for Hub authentication; the older
# `use_auth_token=` spelling is deprecated in transformers.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, token=TOKEN)

# Use the GPU when torch reports one, otherwise fall back to the CPU, and move
# the model there once at start-up so every request reuses the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
|
15 |
+
|
16 |
+
def clean_text(text):
    """Normalize a raw message before it is handed to the tokenizer.

    Surrounding whitespace is trimmed, newlines become spaces, every character
    that is not a word character, whitespace, comma or period is removed, and
    the result is lower-cased.

    Args:
        text: The raw message. ``None`` is treated as an empty message.

    Returns:
        The cleaned, lower-cased string ("" for ``None`` or empty input).
    """
    # A cleared input field or an API call may deliver None instead of "".
    if text is None:
        return ""

    flattened = text.strip().replace('\n', ' ')
    # Single pass: anything outside word characters, whitespace, ',' and '.'
    # is dropped -- this deliberately includes '!' and '?'.
    filtered = re.sub(r'[^\w\s,.]', '', flattened, flags=re.UNICODE)
    # Lower-casing stays last so the character filter sees the original text.
    return filtered.lower()
|
22 |
+
|
23 |
def classify_text(text):
    """Classify a message as spam or not spam and report the model's score.

    The text is normalized with ``clean_text``, tokenized to a fixed length of
    128 tokens (padded or truncated), and run through the model without
    gradient tracking. The sigmoid of the first logit is used as the score.

    Args:
        text: The raw message to classify.

    Returns:
        A string with the label ("СПАМ" when the score is at least 0.5,
        otherwise "НЕ СПАМ") followed by the score as a percentage.
    """
    cleaned = clean_text(text)
    batch = tokenizer(
        cleaned,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    ids = batch['input_ids'].to(device)
    mask = batch['attention_mask'].to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits
        scores = torch.sigmoid(logits).cpu().numpy()
        # First row of the batch, first output unit.
        prediction = scores[0][0]

    if prediction >= 0.5:
        label = "СПАМ"
    else:
        label = "НЕ СПАМ"
    return f"{label} (вероятность: {prediction*100:.2f}%)"
|
33 |
|
34 |
iface = gr.Interface(
|
35 |
fn=classify_text,
|
|
|
39 |
description="Введите текст, чтобы проверить, является ли он спамом."
|
40 |
)
|
41 |
|
42 |
+
iface.launch()
|