NeuroSpaceX commited on
Commit
414c376
·
verified ·
1 Parent(s): f86ae42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -11,7 +11,8 @@ models = {
11
  "ruSpamNS_v13": "NeuroSpaceX/ruSpamNS_v13",
12
  "ruSpamNS_big": "NeuroSpaceX/ruSpamNS_big",
13
  "ruSpamNS_small": "NeuroSpaceX/ruSpamNS_small",
14
- "ruSpamNS_v14": "NeuroSpaceX/ruSpamNS_v14" # Добавлена новая модель
 
15
  }
16
 
17
  tokenizers = {name: AutoTokenizer.from_pretrained(path, use_auth_token=TOKEN) for name, path in models.items()}
@@ -38,16 +39,23 @@ def classify_text(text, model_choice):
38
 
39
  with torch.no_grad():
40
  outputs = model(input_ids, attention_mask=attention_mask).logits
41
- prediction = torch.sigmoid(outputs).cpu().numpy()[0][0]
42
-
43
- label = "СПАМ" if prediction >= 0.5 else "НЕ СПАМ"
44
- return f"{label} (вероятность: {prediction*100:.2f}%)"
 
 
 
 
 
 
 
45
 
46
  iface = gr.Interface(
47
  fn=classify_text,
48
  inputs=[
49
  gr.Textbox(lines=3, placeholder="Введите текст..."),
50
- gr.Radio(["ruSpamNS_v13", "ruSpamNS_big", "ruSpamNS_small", "ruSpamNS_v14"], label="Выберите модель") # Добавлена новая модель в выбор
51
  ],
52
  outputs="text",
53
  title="ruSpamNS - Проверка на спам",
 
11
  "ruSpamNS_v13": "NeuroSpaceX/ruSpamNS_v13",
12
  "ruSpamNS_big": "NeuroSpaceX/ruSpamNS_big",
13
  "ruSpamNS_small": "NeuroSpaceX/ruSpamNS_small",
14
+ "ruSpamNS_v14": "NeuroSpaceX/ruSpamNS_v14",
15
+ "ruSpamNS_v14_multiclass": "NeuroSpaceX/ruSpamNS_v14_multiclass"
16
  }
17
 
18
  tokenizers = {name: AutoTokenizer.from_pretrained(path, use_auth_token=TOKEN) for name, path in models.items()}
 
39
 
40
  with torch.no_grad():
41
  outputs = model(input_ids, attention_mask=attention_mask).logits
42
+ if model_choice == "ruSpamNS_v14_multiclass":
43
+ probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0]
44
+ labels = ["НЕ СПАМ", "СПАМ", "НЕДВИЖИМОСТЬ", "ВАКАНСИИ"]
45
+ predicted_index = probabilities.argmax()
46
+ predicted_label = labels[predicted_index]
47
+ confidence = probabilities[predicted_index] * 100
48
+ return f"{predicted_label} (вероятность: {confidence:.2f}%)"
49
+ else:
50
+ prediction = torch.sigmoid(outputs).cpu().numpy()[0][0]
51
+ label = "СПАМ" if prediction >= 0.5 else "НЕ СПАМ"
52
+ return f"{label} (вероятность: {prediction*100:.2f}%)"
53
 
54
  iface = gr.Interface(
55
  fn=classify_text,
56
  inputs=[
57
  gr.Textbox(lines=3, placeholder="Введите текст..."),
58
+ gr.Radio(list(models.keys()), label="Выберите модель")
59
  ],
60
  outputs="text",
61
  title="ruSpamNS - Проверка на спам",