NimaKL committed on
Commit
fe72a06
·
1 Parent(s): dc010b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -68
app.py CHANGED
@@ -10,71 +10,67 @@ with col1:
10
  st.markdown("Message spam detection tool for Turkish language. Due the small size of the dataset, I decided to go with transformers technology Google BERT. Using the Turkish pre-trained model BERTurk, I imporved the accuracy of the tool by 18 percent compared to the previous model which used fastText.")
11
 
12
 
13
- def predict(new_sentence):
14
- # We need Token IDs and Attention Mask for inference on the new sentence
15
- test_ids = []
16
- test_attention_mask = []
17
-
18
- # Apply the tokenizer
19
- encoding = preprocessing(new_sentence, tokenizer)
20
-
21
- # Extract IDs and Attention Mask
22
- test_ids.append(encoding['input_ids'])
23
- test_attention_mask.append(encoding['attention_mask'])
24
- test_ids = torch.cat(test_ids, dim = 0)
25
- test_attention_mask = torch.cat(test_attention_mask, dim = 0)
26
-
27
- # Forward pass, calculate logit predictions
28
- with torch.no_grad():
29
- output = model(test_ids.to(device), token_type_ids = None, attention_mask = test_attention_mask.to(device))
30
-
31
- prediction = 'Spam' if np.argmax(output.logits.cpu().numpy()).flatten().item() == 1 else 'Normal'
32
- pred = 'Predicted Class: '+ prediction
33
- with col2:
34
- st.header(pred)
35
-
36
- with col2:
37
- text = st.text_input("Enter the text you'd like to analyze for spam.")
38
- if text or st.button('Analyze'):
39
- predict(text)
40
-
41
-
42
-
43
-
44
- import torch
45
- import numpy as np
46
-
47
- from transformers import AutoTokenizer
48
- tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-turkish-uncased")
49
- from transformers import AutoModel
50
- model = BertForSequenceClassification.from_pretrained("NimaKL/spamd_model")
51
-
52
- token_id = []
53
- attention_masks = []
54
-
55
- def preprocessing(input_text, tokenizer):
56
- '''
57
- Returns <class transformers.tokenization_utils_base.BatchEncoding> with the following fields:
58
- - input_ids: list of token ids
59
- - token_type_ids: list of token type ids
60
- - attention_mask: list of indices (0,1) specifying which tokens should considered by the model (return_attention_mask = True).
61
- '''
62
- return tokenizer.encode_plus(
63
- input_text,
64
- add_special_tokens = True,
65
- max_length = 32,
66
- pad_to_max_length = True,
67
- return_attention_mask = True,
68
- return_tensors = 'pt'
69
- )
70
-
71
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
72
- #Used for printing the name if the variables. Removing it will not intrupt the project.
73
- def namestr(obj, namespace):
74
- return [name for name in namespace if namespace[name] is obj]
75
-
76
-
77
-
78
- #st.write('Input', namestr(new_sentence, globals()),': \n', new_sentence)
79
-
80
-
 
10
  st.markdown("Message spam detection tool for Turkish language. Due the small size of the dataset, I decided to go with transformers technology Google BERT. Using the Turkish pre-trained model BERTurk, I imporved the accuracy of the tool by 18 percent compared to the previous model which used fastText.")
11
 
12
 
13
+ if st.button('Load Model'):
14
+ with st.spinner('Wait for it...'):
15
+
16
+ import torch
17
+ import numpy as np
18
+
19
+ from transformers import AutoTokenizer
20
+ tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-turkish-uncased")
21
+ from transformers import AutoModel
22
+ model = BertForSequenceClassification.from_pretrained("NimaKL/spamd_model")
23
+
24
+ token_id = []
25
+ attention_masks = []
26
+
27
+ def preprocessing(input_text, tokenizer):
28
+ '''
29
+ Returns <class transformers.tokenization_utils_base.BatchEncoding> with the following fields:
30
+ - input_ids: list of token ids
31
+ - token_type_ids: list of token type ids
32
+ - attention_mask: list of indices (0,1) specifying which tokens should considered by the model (return_attention_mask = True).
33
+ '''
34
+ return tokenizer.encode_plus(
35
+ input_text,
36
+ add_special_tokens = True,
37
+ max_length = 32,
38
+ pad_to_max_length = True,
39
+ return_attention_mask = True,
40
+ return_tensors = 'pt'
41
+ )
42
+
43
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44
+ #Used for printing the name if the variables. Removing it will not intrupt the project.
45
+ def namestr(obj, namespace):
46
+ return [name for name in namespace if namespace[name] is obj]
47
+
48
+ def predict(new_sentence):
49
+ # We need Token IDs and Attention Mask for inference on the new sentence
50
+ test_ids = []
51
+ test_attention_mask = []
52
+
53
+ # Apply the tokenizer
54
+ encoding = preprocessing(new_sentence, tokenizer)
55
+
56
+ # Extract IDs and Attention Mask
57
+ test_ids.append(encoding['input_ids'])
58
+ test_attention_mask.append(encoding['attention_mask'])
59
+ test_ids = torch.cat(test_ids, dim = 0)
60
+ test_attention_mask = torch.cat(test_attention_mask, dim = 0)
61
+
62
+ # Forward pass, calculate logit predictions
63
+ with torch.no_grad():
64
+ output = model(test_ids.to(device), token_type_ids = None, attention_mask = test_attention_mask.to(device))
65
+
66
+ prediction = 'Spam' if np.argmax(output.logits.cpu().numpy()).flatten().item() == 1 else 'Normal'
67
+ pred = 'Predicted Class: '+ prediction
68
+ with col2:
69
+ st.header(pred)
70
+
71
+ #st.write('Input', namestr(new_sentence, globals()),': \n', new_sentence)
72
+ with col2:
73
+ text = st.text_input("Enter the text you'd like to analyze for spam.")
74
+ if text or st.button('Analyze'):
75
+ predict(text)
76
+ st.success()