riyageorge commited on
Commit
37bc6d8
·
1 Parent(s): c9a87cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py CHANGED
@@ -25,6 +25,28 @@ def classify_image(img, cnn_model):
25
  return "No Tumor"
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # Load your RNN SMS spam detection model
30
  rnn_smsspam_model = tf.keras.models.load_model('rnn_smsspam_model.h5')
@@ -107,6 +129,17 @@ def main():
107
  else:
108
  st.write("Please enter some text for prediction")
109
 
 
 
 
 
 
 
 
 
 
 
 
110
  elif model == "LSTM":
111
  st.subheader("SMS Spam Detection")
112
  user_input = st.text_area("Enter a message to classify as 'Spam' or 'Not spam': ")
 
25
  return "No Tumor"
26
 
27
 
28
+ # Load your DNN SMS spam detection model
29
+ dnn_smsspam_model = tf.keras.models.load_model('dnn_smsspam_model.h5')
30
+ # Load the saved tokenizer
31
+ with open('dnn_smsspam_tokenizer.pickle', 'rb') as handle:
32
+ dnn_smsspam_tokenizer = pickle.load(handle)
33
+
34
+ def dnn_predict_message(input_text):
35
+ max_length=20
36
+ # Process input text similarly to training data
37
+ encoded_input = dnn_smsspam_tokenizer.texts_to_sequences([input_text])
38
+ padded_input = tf.keras.preprocessing.sequence.pad_sequences(encoded_input, maxlen=max_length, padding='post')
39
+ # Get the probabilities of being classified as "Spam" for each input
40
+ predictions = dnn_smsspam_model.predict(padded_input)
41
+ # Define a threshold (e.g., 0.5) for classification
42
+ threshold = 0.5
43
+ # Make the predictions based on the threshold for each input
44
+ for prediction in predictions:
45
+ if prediction > threshold:
46
+ return "Spam"
47
+ else:
48
+ return "Not spam"
49
+
50
 
51
  # Load your RNN SMS spam detection model
52
  rnn_smsspam_model = tf.keras.models.load_model('rnn_smsspam_model.h5')
 
129
  else:
130
  st.write("Please enter some text for prediction")
131
 
132
+ elif model == "DNN":
133
+ st.subheader("SMS Spam Detection")
134
+ user_input = st.text_area("Enter a message to classify as 'Spam' or 'Not spam': ")
135
+
136
+ if st.button("Predict"):
137
+ if user_input:
138
+ prediction_result = dnn_predict_message(user_input)
139
+ st.write(f"The message is classified as: {prediction_result}")
140
+ else:
141
+ st.write("Please enter some text for prediction")
142
+
143
  elif model == "LSTM":
144
  st.subheader("SMS Spam Detection")
145
  user_input = st.text_area("Enter a message to classify as 'Spam' or 'Not spam': ")