Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
|
2 |
|
3 |
import torch
|
4 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
@@ -17,15 +17,19 @@ def preprocess(text):
|
|
17 |
inputs = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors='pt')
|
18 |
return inputs['input_ids'], inputs['attention_mask']
|
19 |
|
20 |
-
# Define a function to classify a text input
|
21 |
def classify(text):
|
22 |
input_ids, attention_mask = preprocess(text)
|
23 |
with torch.no_grad():
|
24 |
logits = model(input_ids, attention_mask=attention_mask).logits
|
25 |
-
preds = torch.sigmoid(logits)
|
26 |
-
return
|
27 |
|
28 |
-
#
|
29 |
-
text = "
|
30 |
preds = classify(text)
|
31 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
!pip install transformers torch
|
2 |
|
3 |
import torch
|
4 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
17 |
inputs = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors='pt')
|
18 |
return inputs['input_ids'], inputs['attention_mask']
|
19 |
|
20 |
+
# Define a function to classify a text input and return the predicted categories with probabilities
|
21 |
def classify(text):
|
22 |
input_ids, attention_mask = preprocess(text)
|
23 |
with torch.no_grad():
|
24 |
logits = model(input_ids, attention_mask=attention_mask).logits
|
25 |
+
preds = torch.sigmoid(logits).squeeze().tolist()
|
26 |
+
return {labels[i]: preds[i] for i in range(len(labels))}
|
27 |
|
28 |
+
# Prompt the user to input text and classify it
|
29 |
+
text = input("Enter text to check toxicity: ")
|
30 |
preds = classify(text)
|
31 |
+
|
32 |
+
# Print the predicted categories with probabilities
|
33 |
+
print("Predicted toxicity categories and probabilities:")
|
34 |
+
for label, prob in preds.items():
|
35 |
+
print(f"{label}: {prob:.2f}")
|