File size: 1,875 Bytes
f9d050d
1946528
f9d050d
 
 
 
 
 
 
bcaa150
f9d050d
 
1946528
f9d050d
 
 
 
 
 
 
1946528
 
f9d050d
 
 
 
 
 
 
 
 
 
 
 
 
bcaa150
f9d050d
 
 
 
 
 
 
 
 
 
1946528
 
f9d050d
 
 
 
 
 
 
 
 
 
bcaa150
f9d050d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import gradio as gr

# Output label names. predict_moderation pairs index i of the model's
# scores with categories[i], so the order here is significant.
categories = [
    'harassment',
    'harassment_threatening',
    'hate',
    'hate_threatening',
    'self_harm',
    'self_harm_instructions',
    'self_harm_intent',
    'sexual',
    'sexual_minors',
    'violence',
    'violence_graphic',
    'self-harm',
    'sexual/minors',
    'hate/threatening',
    'violence/graphic',
    'self-harm/intent',
    'self-harm/instructions',
    'harassment/threatening',
]

# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint identifier handed to both from_pretrained() calls below.
model_name = "ifmain/ModerationBERT-En-02"

tokenizer = BertTokenizer.from_pretrained(model_name)

# One classification output per entry in `categories` (18 in total).
model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(categories),
)
model.to(device)

def predict_moderation(text):
    """Score *text* against every moderation category.

    The text is tokenized to exactly 128 token ids (padded or truncated),
    run through the module-level ``model`` with gradient tracking disabled,
    and a sigmoid is applied to each logit of the first (only) batch row.

    Returns a 2-tuple:
        * a dict mapping each name in ``categories`` to its score (float);
        * the string ``"True"`` if any score is above 0.5, else ``"False"``.
    """
    tokenized = tokenizer.encode_plus(
        text,
        truncation=True,
        padding='max_length',
        max_length=128,
        add_special_tokens=True,
        return_attention_mask=True,
        return_token_type_ids=False,
        return_tensors='pt',
    )

    # Inputs must live on the same device the model was moved to.
    ids = tokenized['input_ids'].to(device)
    mask = tokenized['attention_mask'].to(device)

    model.eval()
    with torch.no_grad():
        result = model(ids, attention_mask=mask)

    # Independent sigmoid per label, then back to the CPU as a NumPy array.
    scores = torch.sigmoid(result.logits)[0].cpu().numpy()

    category_scores = {
        name: float(scores[idx]) for idx, name in enumerate(categories)
    }

    # Flag the text as soon as at least one score exceeds the 0.5 cut-off.
    flagged = bool((scores > 0.5).any())

    return category_scores, str(flagged)


# Single free-text input passed to predict_moderation as `text`.
_text_input = gr.Textbox(label="Enter text")

# Two label widgets, in the same order as predict_moderation's return
# values: the per-category scores first, then the overall flag string.
_result_views = [
    gr.Label(label="Ratings by category"),
    gr.Label(label="Was a violation detected?"),
]

iface = gr.Interface(
    title="Text moderation",
    description="Enter text to check it for content violations (ModerationBERT-En-02 model).",
    fn=predict_moderation,
    inputs=_text_input,
    outputs=_result_views,
)

# Launch the interface as soon as this module is executed.
iface.launch()