ChinmayBH commited on
Commit
1fa0f1f
·
verified ·
1 Parent(s): 7f47f8b

added confidences to predictions

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -19,7 +19,7 @@ def train_iris_model(algorithm):
19
  if algorithm == 'KNN':
20
  model = KNeighborsClassifier()
21
  elif algorithm == 'SVM':
22
- model = SVC()
23
  elif algorithm == "logistic regression":
24
  model = LogisticRegression()
25
  elif algorithm == 'Random Forest':
@@ -28,6 +28,7 @@ def train_iris_model(algorithm):
28
  model = AdaBoostClassifier()
29
  elif algorithm == 'Decision tree':
30
  model = DecisionTreeClassifier()
 
31
  model.fit(X_train, y_train)
32
 
33
  return model
@@ -39,7 +40,12 @@ def predict_iris_species(model, input_data):
39
  # Make predictions using the trained model
40
  prediction = model.predict(input_data)
41
 
42
- return prediction
 
 
 
 
 
43
 
44
  def main():
45
  st.title("Iris Species Prediction App")
@@ -50,7 +56,6 @@ def main():
50
  # Train the model based on user's choice
51
  trained_model = train_iris_model(algorithm)
52
 
53
-
54
  st.sidebar.header("User Input")
55
  sepal_length = st.sidebar.slider("Sepal Length", 0.0, 10.0, 5.0)
56
  sepal_width = st.sidebar.slider("Sepal Width", 0.0, 10.0, 5.0)
@@ -58,7 +63,7 @@ def main():
58
  petal_width = st.sidebar.slider("Petal Width", 0.0, 10.0, 5.0)
59
 
60
  input_values = [sepal_length, sepal_width, petal_length, petal_width]
61
- prediction_result = predict_iris_species(trained_model, input_values)
62
 
63
  species_mapping = {0: 'Iris-setosa', 1: 'Iris-virginica', 2: 'Iris-versicolor'}
64
  predicted_species = species_mapping.get(prediction_result[0], 'Unknown')
@@ -72,6 +77,9 @@ def main():
72
  st.subheader("Prediction:")
73
  st.success(f"Predicted Species: {predicted_species}")
74
 
 
 
 
75
  # Display relevant images based on prediction
76
  if predicted_species == 'Iris-setosa':
77
  st.image('setosa_image.jpg', caption='Iris-setosa', use_column_width=True)
 
19
  if algorithm == 'KNN':
20
  model = KNeighborsClassifier()
21
  elif algorithm == 'SVM':
22
+ model = SVC(probability=True)
23
  elif algorithm == "logistic regression":
24
  model = LogisticRegression()
25
  elif algorithm == 'Random Forest':
 
28
  model = AdaBoostClassifier()
29
  elif algorithm == 'Decision tree':
30
  model = DecisionTreeClassifier()
31
+
32
  model.fit(X_train, y_train)
33
 
34
  return model
 
40
  # Make predictions using the trained model
41
  prediction = model.predict(input_data)
42
 
43
+ # Check if the model has a predict_proba method
44
+ if hasattr(model, 'predict_proba'):
45
+ confidence = model.predict_proba(input_data).max()
46
+ return prediction, confidence
47
+ else:
48
+ return prediction, None
49
 
50
  def main():
51
  st.title("Iris Species Prediction App")
 
56
  # Train the model based on user's choice
57
  trained_model = train_iris_model(algorithm)
58
 
 
59
  st.sidebar.header("User Input")
60
  sepal_length = st.sidebar.slider("Sepal Length", 0.0, 10.0, 5.0)
61
  sepal_width = st.sidebar.slider("Sepal Width", 0.0, 10.0, 5.0)
 
63
  petal_width = st.sidebar.slider("Petal Width", 0.0, 10.0, 5.0)
64
 
65
  input_values = [sepal_length, sepal_width, petal_length, petal_width]
66
+ prediction_result, confidence = predict_iris_species(trained_model, input_values)
67
 
68
  species_mapping = {0: 'Iris-setosa', 1: 'Iris-virginica', 2: 'Iris-versicolor'}
69
  predicted_species = species_mapping.get(prediction_result[0], 'Unknown')
 
77
  st.subheader("Prediction:")
78
  st.success(f"Predicted Species: {predicted_species}")
79
 
80
+ if confidence is not None:
81
+ st.info(f"Confidence of prediction: {confidence * 100:.2f}%")
82
+
83
  # Display relevant images based on prediction
84
  if predicted_species == 'Iris-setosa':
85
  st.image('setosa_image.jpg', caption='Iris-setosa', use_column_width=True)