dp92 committed
Commit 230173f · 1 Parent(s): 46838b7

Update app.py

Files changed (1):
  1. app.py +11 -7
app.py CHANGED
@@ -1,4 +1,4 @@
-
+!pip install transformers torch
 
 import torch
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
@@ -17,15 +17,19 @@ def preprocess(text):
     inputs = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors='pt')
     return inputs['input_ids'], inputs['attention_mask']
 
-# Define a function to classify a text input
+# Define a function to classify a text input and return the predicted categories with probabilities
 def classify(text):
     input_ids, attention_mask = preprocess(text)
     with torch.no_grad():
         logits = model(input_ids, attention_mask=attention_mask).logits
-    preds = torch.sigmoid(logits) > 0.5
-    return [labels[i] for i, pred in enumerate(preds.squeeze().tolist()) if pred]
+    preds = torch.sigmoid(logits).squeeze().tolist()
+    return {labels[i]: preds[i] for i in range(len(labels))}
 
-# Example usage
-text = "You are a stupid idiot"
+# Prompt the user to input text and classify it
+text = input("Enter text to check toxicity: ")
 preds = classify(text)
-print(preds) # Output: ['toxic', 'insult']
+
+# Print the predicted categories with probabilities
+print("Predicted toxicity categories and probabilities:")
+for label, prob in preds.items():
+    print(f"{label}: {prob:.2f}")
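For reference, the names `tokenizer`, `model`, and `labels` used inside classify() are defined in the part of app.py outside these hunks. The sketch below shows one plausible setup and how a caller could recover the old thresholded-label behavior from the new probability dict; the checkpoint name, label order, and the to_hard_labels helper are assumptions for illustration, not taken from this commit.

# Hypothetical setup for the names referenced in the diff; the actual
# app.py may load a different checkpoint or label list.
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = "unitary/toxic-bert"  # assumed multi-label toxicity model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# The new classify() returns {label: probability} instead of the old list of
# labels whose sigmoid score exceeded 0.5. Callers that still want hard labels
# can re-apply the threshold on the returned dict (illustrative helper):
def to_hard_labels(scores, threshold=0.5):
    return [label for label, prob in scores.items() if prob > threshold]

Under these assumptions, to_hard_labels(classify("You are a stupid idiot")) would return something like the removed example output ['toxic', 'insult'], while the new script prints a probability per category instead.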