Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -1,189 +1,31 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
import
|
4 |
-
import
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
#
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
model, tokenizer = load_models()
|
34 |
-
|
35 |
-
def prepare_text(text):
|
36 |
-
'''Need to Tokenize text and pad sequences'''
|
37 |
-
words = tokenizer.texts_to_sequences(text)
|
38 |
-
words_padded = pad_sequences(words, maxlen = 150)
|
39 |
-
|
40 |
-
return words_padded
|
41 |
-
|
42 |
-
# Initialize plot
|
43 |
-
fig = go.Figure(
|
44 |
-
data=[go.Bar(x=["Toxic", "Severe toxic", "Obscene", "Threat", "Insult", "Identity hate"],
|
45 |
-
y=[0, 0, 0, 0, 0, 0],
|
46 |
-
marker=dict(color="#673EF1"),
|
47 |
-
width=[0.4]*6)],
|
48 |
-
layout=go.Layout(
|
49 |
-
title=go.layout.Title(text="Your comment is...", font=dict(size=20)),
|
50 |
-
plot_bgcolor='rgba(0,0,0,0)',
|
51 |
-
font=dict(
|
52 |
-
family='Source Sans Pro'
|
53 |
-
),
|
54 |
-
xaxis=dict(
|
55 |
-
tickfont=dict(size=16)
|
56 |
-
),
|
57 |
-
yaxis=dict(
|
58 |
-
title_text="Probability",
|
59 |
-
tickmode="array",
|
60 |
-
tickfont=dict(size=16),
|
61 |
-
range=[0, 1]
|
62 |
-
),
|
63 |
-
margin=dict(
|
64 |
-
b=100,
|
65 |
-
t=30,
|
66 |
-
),
|
67 |
-
title_x=0.5
|
68 |
-
)
|
69 |
-
)
|
70 |
-
|
71 |
-
# Layout of the app
|
72 |
-
app.layout = html.Div(children=[
|
73 |
-
html.Div(children=[
|
74 |
-
html.Img(
|
75 |
-
src='assets/eyes.gif',
|
76 |
-
style={
|
77 |
-
'width': '200px'
|
78 |
-
})
|
79 |
-
], style={'textAlign': 'center', 'backgroundColor': '#000000'}),
|
80 |
-
html.H1("The Toxic Comment Agent"),
|
81 |
-
|
82 |
-
html.Div("Discussing things you care about can be difficult. The threat of abuse and harassment online means that many people stop expressing themselves and give up on seeking different opinions. This tool uses a multi-label model that can detect the type of toxicity of a comment.",
|
83 |
-
style={'textAlign': 'center', 'margin-bottom': '40px', 'margin-left': '300px', 'margin-right': '300px'}),
|
84 |
-
|
85 |
-
html.H2("Enter a comment below and I'll predict what I think about it", style={'fontSize': 24}),
|
86 |
-
html.Div(children=[
|
87 |
-
html.Img(
|
88 |
-
src='assets/arrow.gif',
|
89 |
-
style={
|
90 |
-
'width': '50px'
|
91 |
-
})
|
92 |
-
], style={'textAlign': 'center', 'margin-bottom': '20px'}),
|
93 |
-
|
94 |
-
# Block of the text area
|
95 |
-
html.Div(children=[
|
96 |
-
dcc.Textarea(id = 'comment', value = '', style={'width': '50%', 'rows': '5'})
|
97 |
-
], style={'textAlign': 'center'}),
|
98 |
-
|
99 |
-
# Check button
|
100 |
-
html.Div(children=[
|
101 |
-
html.Button(id = 'submit-button', n_clicks = 0, children = 'Check', className='button-submit')
|
102 |
-
], style={'textAlign': 'center', 'fontSize': 22, 'height': '100px', 'margin-bottom': '0px'}),
|
103 |
-
|
104 |
-
# Display graph
|
105 |
-
dcc.Graph(id='update-chart', figure=fig, style={
|
106 |
-
'height': 300,
|
107 |
-
'width': 700,
|
108 |
-
"display": "block",
|
109 |
-
"margin-left": "auto",
|
110 |
-
"margin-right": "auto",
|
111 |
-
}),
|
112 |
-
|
113 |
-
html.Div([
|
114 |
-
# About this project
|
115 |
-
html.H2("About this project", style={'fontSize': 28, 'color': '#673EF1'}),
|
116 |
-
html.Div('This predictive tool was built as part of a student project during our Post Master degree in Big Data at Télécom Paris. The data used to build the tool are from the Kaggle "Toxic Comment Classification Challenge" organized by Jigsaw and Conversation AI.',
|
117 |
-
style={'textAlign': 'center', 'margin-left': '300px', 'margin-right': '300px', 'margin-top': '30px', 'margin-bottom': '30px'}),
|
118 |
-
html.Div(children=[
|
119 |
-
html.A([html.Img(src='assets/github-icon.png', style={'width': '30px'})], href='https://github.com/camillecochener/Toxic_comment_classification_challenge')
|
120 |
-
],
|
121 |
-
style={'textAlign': 'center', 'margin-bottom': '40px'}),
|
122 |
-
|
123 |
-
# About us
|
124 |
-
html.H2("About us", style={'fontSize': 28, 'color': '#673EF1'}),
|
125 |
-
html.Div(children=[
|
126 |
-
dbc.Row([
|
127 |
-
dbc.Col(html.Img(src='assets/camille.png', style={'width': '100px', 'margin-left': '10px', 'margin-right': '10px'})),
|
128 |
-
dbc.Col(html.Img(src='assets/hamza.png', style={'width': '100px', 'margin-left': '10px', 'margin-right': '10px'})),
|
129 |
-
dbc.Col(html.Img(src='assets/sophie.png', style={'width': '100px', 'margin-left': '10px', 'margin-right': '10px'})),
|
130 |
-
dbc.Col(html.Img(src='assets/rodolphe.png', style={'width': '100px', 'margin-left': '10px', 'margin-right': '10px'}))
|
131 |
-
]),
|
132 |
-
dbc.Row([
|
133 |
-
dbc.Col(html.Div("Camille COCHENER", style={'margin-left': '10px', 'margin-right': '10px'})),
|
134 |
-
dbc.Col(html.Div("Hamza AMRI", style={'margin-left': '10px', 'margin-right': '10px'})),
|
135 |
-
dbc.Col(html.Div("Sophie LEVEUGLE", style={'margin-left': '10px', 'margin-right': '10px'})),
|
136 |
-
dbc.Col(html.Div("Rodolphe SIMONEAU", style={'margin-left': '10px', 'margin-right': '10px'})),
|
137 |
-
])
|
138 |
-
], style={'textAlign': 'center', 'height': '100px', 'margin-left':'300px', 'margin-right': '300px', 'margin-top': '30px', 'margin-bottom': '30px'}),
|
139 |
-
html.Div('© Copyright TheToxicCommentAgent', style={'textAlign': 'center', 'margin-top': '40px'})
|
140 |
-
], style={'backgroundColor': '#F6F6F6'})
|
141 |
-
|
142 |
-
])
|
143 |
-
|
144 |
-
@app.callback(Output('update-chart', 'figure'),
|
145 |
-
[Input('submit-button', 'n_clicks')],
|
146 |
-
[State('comment', 'value')])
|
147 |
-
def predict_text(submit, comment):
|
148 |
-
if comment is '':
|
149 |
-
empty_fig = go.Figure(
|
150 |
-
data=[go.Bar(x=["Toxic", "Severe toxic", "Obscene", "Threat", "Insult", "Identity hate"],
|
151 |
-
y=[0, 0, 0, 0, 0, 0],
|
152 |
-
marker=dict(color="#673EF1"),
|
153 |
-
width=[0.4]*6)],
|
154 |
-
layout=go.Layout(
|
155 |
-
title=go.layout.Title(text="Your comment is...", font=dict(size=20)),
|
156 |
-
plot_bgcolor='rgba(0,0,0,0)',
|
157 |
-
font=dict(
|
158 |
-
family='Source Sans Pro'
|
159 |
-
),
|
160 |
-
xaxis=dict(
|
161 |
-
tickfont=dict(size=16)
|
162 |
-
),
|
163 |
-
yaxis=dict(
|
164 |
-
title_text="Probability",
|
165 |
-
tickmode="array",
|
166 |
-
tickfont=dict(size=16),
|
167 |
-
range=[0, 1]
|
168 |
-
),
|
169 |
-
margin=dict(
|
170 |
-
b=100,
|
171 |
-
t=30,
|
172 |
-
),
|
173 |
-
title_x=0.5
|
174 |
-
)
|
175 |
-
)
|
176 |
-
return empty_fig
|
177 |
-
else:
|
178 |
-
try:
|
179 |
-
clean_text = prepare_text([comment])
|
180 |
-
preds = model.predict(clean_text)
|
181 |
-
print(preds[0])
|
182 |
-
yvalue = [i for i in preds[0]]
|
183 |
-
return fig.update_traces(y=yvalue)
|
184 |
-
except ValueError as e:
|
185 |
-
print(e)
|
186 |
-
return "The text area is empty."
|
187 |
-
|
188 |
-
if __name__ == '__main__':
|
189 |
-
app.run_server(debug=False, threaded = False)
|
|
|
1 |
+
!pip install transformers torch
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
5 |
+
|
6 |
+
# Load the pretrained BERT model and tokenizer
|
7 |
+
model_name = 'bert-base-uncased'
|
8 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
9 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=6)
|
10 |
+
|
11 |
+
# Define the labels and their corresponding indices
|
12 |
+
labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
|
13 |
+
label2id = {label: i for i, label in enumerate(labels)}
|
14 |
+
|
15 |
+
# Define a function to preprocess the text input
|
16 |
+
def preprocess(text):
|
17 |
+
inputs = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors='pt')
|
18 |
+
return inputs['input_ids'], inputs['attention_mask']
|
19 |
+
|
20 |
+
# Define a function to classify a text input
|
21 |
+
def classify(text):
|
22 |
+
input_ids, attention_mask = preprocess(text)
|
23 |
+
with torch.no_grad():
|
24 |
+
logits = model(input_ids, attention_mask=attention_mask).logits
|
25 |
+
preds = torch.sigmoid(logits) > 0.5
|
26 |
+
return [labels[i] for i, pred in enumerate(preds.squeeze().tolist()) if pred]
|
27 |
+
|
28 |
+
# Example usage
|
29 |
+
text = "You are a stupid idiot"
|
30 |
+
preds = classify(text)
|
31 |
+
print(preds) # Output: ['toxic', 'insult']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|