NeuroSpaceX commited on
Commit
f04015a
·
verified ·
1 Parent(s): 4950ce7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -6
app.py CHANGED
@@ -2,19 +2,34 @@ import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import torch
4
  import os
 
5
 
6
  MODEL_NAME = "NeuroSpaceX/ruSpamNS"
7
- TOKEN = os.getenv("HF_TOKEN") # Читаем токен из переменной окружения
8
 
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=TOKEN)
10
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, use_auth_token=TOKEN)
11
 
 
 
 
 
 
 
 
 
 
 
12
  def classify_text(text):
13
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
 
 
14
  with torch.no_grad():
15
- outputs = model(**inputs)
16
- prediction = torch.argmax(outputs.logits, dim=1).item()
17
- return "СПАМ" if prediction == 1 else "НЕ СПАМ"
 
18
 
19
  iface = gr.Interface(
20
  fn=classify_text,
@@ -24,4 +39,4 @@ iface = gr.Interface(
24
  description="Введите текст, чтобы проверить, является ли он спамом."
25
  )
26
 
27
- iface.launch()
 
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import torch
4
  import os
5
+ import re
6
 
7
  MODEL_NAME = "NeuroSpaceX/ruSpamNS"
8
+ TOKEN = os.getenv("HF_TOKEN")
9
 
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=TOKEN)
11
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, use_auth_token=TOKEN)
12
 
13
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+ model.to(device)
15
+
16
+ def clean_text(text):
17
+ text = text.strip()
18
+ text = text.replace('\n', ' ')
19
+ text = re.sub(r'[^\w\s,.!?]', '', text, flags=re.UNICODE)
20
+ text = re.sub(r'[!?]', '', text)
21
+ return text.lower()
22
+
23
  def classify_text(text):
24
+ message = clean_text(text)
25
+ encoding = tokenizer(message, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
26
+ input_ids = encoding['input_ids'].to(device)
27
+ attention_mask = encoding['attention_mask'].to(device)
28
  with torch.no_grad():
29
+ outputs = model(input_ids, attention_mask=attention_mask).logits
30
+ prediction = torch.sigmoid(outputs).cpu().numpy()[0][0]
31
+ label = "СПАМ" if prediction >= 0.5 else "НЕ СПАМ"
32
+ return f"{label} (вероятность: {prediction*100:.2f}%)"
33
 
34
  iface = gr.Interface(
35
  fn=classify_text,
 
39
  description="Введите текст, чтобы проверить, является ли он спамом."
40
  )
41
 
42
+ iface.launch()