dp92 committed
Commit 4c30733 · 1 parent: feff62c

Update app.py

Files changed (1)
  1. app.py +31 -189
app.py CHANGED
@@ -1,189 +1,31 @@
- import os
-
- import dash
- import dash_core_components as dcc
- import dash_html_components as html
- import dash_bootstrap_components as dbc
- from dash.dependencies import Input, Output, State
-
- external_stylesheets = [dbc.themes.BOOTSTRAP]
-
- #Model dependencies
- import numpy as np
- from tensorflow.keras.models import load_model
- from tensorflow.keras.preprocessing.text import Tokenizer
- from tensorflow.keras.preprocessing.sequence import pad_sequences
- import pickle
- import plotly
- import plotly.graph_objects as go
-
-
- app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
- server = app.server
-
- def load_models():
-     # Load my pre-trained Keras model
-     global model, tokenizer
-     model = load_model('model.h5')
-     # load my original tokenizer used to build model
-     with open('tokenizer.pkl', 'rb') as f:
-         tokenizer = pickle.load(f)
-     return model, tokenizer
-
- model, tokenizer = load_models()
-
- def prepare_text(text):
-     '''Need to Tokenize text and pad sequences'''
-     words = tokenizer.texts_to_sequences(text)
-     words_padded = pad_sequences(words, maxlen = 150)
-
-     return words_padded
-
- # Initialize plot
- fig = go.Figure(
-     data=[go.Bar(x=["Toxic", "Severe toxic", "Obscene", "Threat", "Insult", "Identity hate"],
-                  y=[0, 0, 0, 0, 0, 0],
-                  marker=dict(color="#673EF1"),
-                  width=[0.4]*6)],
-     layout=go.Layout(
-         title=go.layout.Title(text="Your comment is...", font=dict(size=20)),
-         plot_bgcolor='rgba(0,0,0,0)',
-         font=dict(
-             family='Source Sans Pro'
-         ),
-         xaxis=dict(
-             tickfont=dict(size=16)
-         ),
-         yaxis=dict(
-             title_text="Probability",
-             tickmode="array",
-             tickfont=dict(size=16),
-             range=[0, 1]
-         ),
-         margin=dict(
-             b=100,
-             t=30,
-         ),
-         title_x=0.5
-     )
- )
-
- # Layout of the app
- app.layout = html.Div(children=[
-     html.Div(children=[
-         html.Img(
-             src='assets/eyes.gif',
-             style={
-                 'width': '200px'
-             })
-     ], style={'textAlign': 'center', 'backgroundColor': '#000000'}),
-     html.H1("The Toxic Comment Agent"),
-
-     html.Div("Discussing things you care about can be difficult. The threat of abuse and harassment online means that many people stop expressing themselves and give up on seeking different opinions. This tool uses a multi-label model that can detect the type of toxicity of a comment.",
-              style={'textAlign': 'center', 'margin-bottom': '40px', 'margin-left': '300px', 'margin-right': '300px'}),
-
-     html.H2("Enter a comment below and I'll predict what I think about it", style={'fontSize': 24}),
-     html.Div(children=[
-         html.Img(
-             src='assets/arrow.gif',
-             style={
-                 'width': '50px'
-             })
-     ], style={'textAlign': 'center', 'margin-bottom': '20px'}),
-
-     # Block of the text area
-     html.Div(children=[
-         dcc.Textarea(id = 'comment', value = '', style={'width': '50%', 'rows': '5'})
-     ], style={'textAlign': 'center'}),
-
-     # Check button
-     html.Div(children=[
-         html.Button(id = 'submit-button', n_clicks = 0, children = 'Check', className='button-submit')
-     ], style={'textAlign': 'center', 'fontSize': 22, 'height': '100px', 'margin-bottom': '0px'}),
-
-     # Display graph
-     dcc.Graph(id='update-chart', figure=fig, style={
-         'height': 300,
-         'width': 700,
-         "display": "block",
-         "margin-left": "auto",
-         "margin-right": "auto",
-     }),
-
-     html.Div([
-         # About this project
-         html.H2("About this project", style={'fontSize': 28, 'color': '#673EF1'}),
-         html.Div('This predictive tool was built as part of a student project during our Post Master degree in Big Data at Télécom Paris. The data used to build the tool are from the Kaggle "Toxic Comment Classification Challenge" organized by Jigsaw and Conversation AI.',
-                  style={'textAlign': 'center', 'margin-left': '300px', 'margin-right': '300px', 'margin-top': '30px', 'margin-bottom': '30px'}),
-         html.Div(children=[
-             html.A([html.Img(src='assets/github-icon.png', style={'width': '30px'})], href='https://github.com/camillecochener/Toxic_comment_classification_challenge')
-         ],
-             style={'textAlign': 'center', 'margin-bottom': '40px'}),
-
-         # About us
-         html.H2("About us", style={'fontSize': 28, 'color': '#673EF1'}),
-         html.Div(children=[
-             dbc.Row([
-                 dbc.Col(html.Img(src='assets/camille.png', style={'width': '100px', 'margin-left': '10px', 'margin-right': '10px'})),
-                 dbc.Col(html.Img(src='assets/hamza.png', style={'width': '100px', 'margin-left': '10px', 'margin-right': '10px'})),
-                 dbc.Col(html.Img(src='assets/sophie.png', style={'width': '100px', 'margin-left': '10px', 'margin-right': '10px'})),
-                 dbc.Col(html.Img(src='assets/rodolphe.png', style={'width': '100px', 'margin-left': '10px', 'margin-right': '10px'}))
-             ]),
-             dbc.Row([
-                 dbc.Col(html.Div("Camille COCHENER", style={'margin-left': '10px', 'margin-right': '10px'})),
-                 dbc.Col(html.Div("Hamza AMRI", style={'margin-left': '10px', 'margin-right': '10px'})),
-                 dbc.Col(html.Div("Sophie LEVEUGLE", style={'margin-left': '10px', 'margin-right': '10px'})),
-                 dbc.Col(html.Div("Rodolphe SIMONEAU", style={'margin-left': '10px', 'margin-right': '10px'})),
-             ])
-         ], style={'textAlign': 'center', 'height': '100px', 'margin-left':'300px', 'margin-right': '300px', 'margin-top': '30px', 'margin-bottom': '30px'}),
-         html.Div('© Copyright TheToxicCommentAgent', style={'textAlign': 'center', 'margin-top': '40px'})
-     ], style={'backgroundColor': '#F6F6F6'})
-
- ])
-
- @app.callback(Output('update-chart', 'figure'),
-               [Input('submit-button', 'n_clicks')],
-               [State('comment', 'value')])
- def predict_text(submit, comment):
-     if comment is '':
-         empty_fig = go.Figure(
-             data=[go.Bar(x=["Toxic", "Severe toxic", "Obscene", "Threat", "Insult", "Identity hate"],
-                          y=[0, 0, 0, 0, 0, 0],
-                          marker=dict(color="#673EF1"),
-                          width=[0.4]*6)],
-             layout=go.Layout(
-                 title=go.layout.Title(text="Your comment is...", font=dict(size=20)),
-                 plot_bgcolor='rgba(0,0,0,0)',
-                 font=dict(
-                     family='Source Sans Pro'
-                 ),
-                 xaxis=dict(
-                     tickfont=dict(size=16)
-                 ),
-                 yaxis=dict(
-                     title_text="Probability",
-                     tickmode="array",
-                     tickfont=dict(size=16),
-                     range=[0, 1]
-                 ),
-                 margin=dict(
-                     b=100,
-                     t=30,
-                 ),
-                 title_x=0.5
-             )
-         )
-         return empty_fig
-     else:
-         try:
-             clean_text = prepare_text([comment])
-             preds = model.predict(clean_text)
-             print(preds[0])
-             yvalue = [i for i in preds[0]]
-             return fig.update_traces(y=yvalue)
-         except ValueError as e:
-             print(e)
-             return "The text area is empty."
-
- if __name__ == '__main__':
-     app.run_server(debug=False, threaded = False)
 
+ !pip install transformers torch
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+ # Load the pretrained BERT model and tokenizer
+ model_name = 'bert-base-uncased'
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=6)
+
+ # Define the labels and their corresponding indices
+ labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
+ label2id = {label: i for i, label in enumerate(labels)}
+
+ # Define a function to preprocess the text input
+ def preprocess(text):
+     inputs = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors='pt')
+     return inputs['input_ids'], inputs['attention_mask']
+
+ # Define a function to classify a text input
+ def classify(text):
+     input_ids, attention_mask = preprocess(text)
+     with torch.no_grad():
+         logits = model(input_ids, attention_mask=attention_mask).logits
+     preds = torch.sigmoid(logits) > 0.5
+     return [labels[i] for i, pred in enumerate(preds.squeeze().tolist()) if pred]
+
+ # Example usage
+ text = "You are a stupid idiot"
+ preds = classify(text)
+ print(preds) # Output: ['toxic', 'insult']
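
Two caveats before running the new app.py as a plain script: `!pip install transformers torch` is notebook shell syntax, not Python, so it belongs in a requirements file or a separate install step; and `bert-base-uncased` with `num_labels=6` attaches a randomly initialized classification head, so the predictions (including the `['toxic', 'insult']` output shown above) are only meaningful once the head has been fine-tuned on the six Jigsaw labels. Below is a minimal sketch of the same pipeline, with the checkpoint name swapped for `unitary/toxic-bert` purely for illustration (a public model fine-tuned on those labels; any equivalently fine-tuned checkpoint would do), and the label names read from the model config instead of being hardcoded:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Assumption: a checkpoint fine-tuned on the Jigsaw toxicity labels;
# unitary/toxic-bert is one public example, substituted here for illustration.
MODEL_NAME = 'unitary/toxic-bert'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()  # make sure dropout is disabled for inference

# Read the label names from the model config rather than assuming an order
labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

def classify(text, threshold=0.5):
    # Tokenize and run a single forward pass without tracking gradients
    inputs = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors='pt')
    with torch.no_grad():
        logits = model(**inputs).logits
    # Multi-label task: an independent sigmoid per class, not a softmax
    probs = torch.sigmoid(logits).squeeze(0)
    return [labels[i] for i, p in enumerate(probs.tolist()) if p > threshold]

if __name__ == '__main__':
    print(classify("You are a stupid idiot"))  # expected to include 'toxic' and 'insult'

The per-class sigmoid is the design choice that makes this multi-label: a single comment can cross the threshold for several labels at once (e.g. both 'toxic' and 'insult'), which a softmax over the six classes would not allow.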