Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -11,7 +11,8 @@ models = {
|
|
11 |
"ruSpamNS_v13": "NeuroSpaceX/ruSpamNS_v13",
|
12 |
"ruSpamNS_big": "NeuroSpaceX/ruSpamNS_big",
|
13 |
"ruSpamNS_small": "NeuroSpaceX/ruSpamNS_small",
|
14 |
-
"ruSpamNS_v14": "NeuroSpaceX/ruSpamNS_v14"
|
|
|
15 |
}
|
16 |
|
17 |
tokenizers = {name: AutoTokenizer.from_pretrained(path, use_auth_token=TOKEN) for name, path in models.items()}
|
@@ -38,16 +39,23 @@ def classify_text(text, model_choice):
|
|
38 |
|
39 |
with torch.no_grad():
|
40 |
outputs = model(input_ids, attention_mask=attention_mask).logits
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
iface = gr.Interface(
|
47 |
fn=classify_text,
|
48 |
inputs=[
|
49 |
gr.Textbox(lines=3, placeholder="Введите текст..."),
|
50 |
-
gr.Radio(
|
51 |
],
|
52 |
outputs="text",
|
53 |
title="ruSpamNS - Проверка на спам",
|
|
|
11 |
"ruSpamNS_v13": "NeuroSpaceX/ruSpamNS_v13",
|
12 |
"ruSpamNS_big": "NeuroSpaceX/ruSpamNS_big",
|
13 |
"ruSpamNS_small": "NeuroSpaceX/ruSpamNS_small",
|
14 |
+
"ruSpamNS_v14": "NeuroSpaceX/ruSpamNS_v14",
|
15 |
+
"ruSpamNS_v14_multiclass": "NeuroSpaceX/ruSpamNS_v14_multiclass"
|
16 |
}
|
17 |
|
18 |
tokenizers = {name: AutoTokenizer.from_pretrained(path, use_auth_token=TOKEN) for name, path in models.items()}
|
|
|
39 |
|
40 |
with torch.no_grad():
|
41 |
outputs = model(input_ids, attention_mask=attention_mask).logits
|
42 |
+
if model_choice == "ruSpamNS_v14_multiclass":
|
43 |
+
probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0]
|
44 |
+
labels = ["НЕ СПАМ", "СПАМ", "НЕДВИЖИМОСТЬ", "ВАКАНСИИ"]
|
45 |
+
predicted_index = probabilities.argmax()
|
46 |
+
predicted_label = labels[predicted_index]
|
47 |
+
confidence = probabilities[predicted_index] * 100
|
48 |
+
return f"{predicted_label} (вероятность: {confidence:.2f}%)"
|
49 |
+
else:
|
50 |
+
prediction = torch.sigmoid(outputs).cpu().numpy()[0][0]
|
51 |
+
label = "СПАМ" if prediction >= 0.5 else "НЕ СПАМ"
|
52 |
+
return f"{label} (вероятность: {prediction*100:.2f}%)"
|
53 |
|
54 |
iface = gr.Interface(
|
55 |
fn=classify_text,
|
56 |
inputs=[
|
57 |
gr.Textbox(lines=3, placeholder="Введите текст..."),
|
58 |
+
gr.Radio(list(models.keys()), label="Выберите модель")
|
59 |
],
|
60 |
outputs="text",
|
61 |
title="ruSpamNS - Проверка на спам",
|