Doom008 committed on
Commit
e10ecbb
·
verified ·
1 Parent(s): df0636c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ import pickle
5
+ from PIL import Image
6
+
7
+ # --- Parameters ---
8
+ # Threshold for minimum probability to classify an image
9
+ threshold = 0.65
10
+
11
+ # --- Load the Trained Model ---
12
+ # This block loads the pre-trained model from a pickle file.
13
+ # Ensure 'model_trained.p' is in the same directory as this script.
14
+ try:
15
+ with open("model_trained.p", "rb") as pickle_in:
16
+ model = pickle.load(pickle_in)
17
+ print("Model loaded successfully.")
18
+ except FileNotFoundError:
19
+ print("Error: 'model_trained.p' not found. Please ensure the model file is in the correct directory.")
20
+ model = None
21
+ except Exception as e:
22
+ print(f"An error occurred while loading the model: {e}")
23
+ model = None
24
+
25
+ # --- Preprocessing Function ---
26
def preProcessing(img):
    """
    Prepare a colour image for the classifier.

    The image is reduced to a single grayscale channel, contrast-enhanced
    with histogram equalization, and scaled by 1/255.

    Args:
        img (numpy.ndarray): The input image in BGR format.

    Returns:
        numpy.ndarray: The equalized grayscale image divided by 255.0.
    """
    # Collapse the three colour channels into one intensity channel
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Spread the intensity histogram to improve contrast
    equalized = cv2.equalizeHist(gray)
    # Scale the pixel values into the 0..1 range
    return equalized / 255.0
44
+
45
+ # --- Prediction Function for Live Feed ---
46
def predict(img):
    """
    Classify a single webcam frame and overlay the result on it.

    The frame is resized to 32x32, preprocessed with ``preProcessing`` and
    passed to the loaded model. The class index and probability are drawn on
    the frame when the top probability exceeds ``threshold``; otherwise a
    "No certain prediction" notice is drawn instead.

    Args:
        img (numpy.ndarray | None): The input frame from the Gradio webcam
            component (in RGB format), or None when no frame is available.

    Returns:
        numpy.ndarray | None: A new frame annotated with the prediction
        details (in RGB format), or None if no frame was supplied. The input
        array is never modified.
    """
    # Gradio may invoke the callback without a frame (e.g. before the webcam
    # has delivered an image or after the input is cleared); there is
    # nothing to annotate in that case.
    if img is None:
        return None

    if model is None:
        # Annotate a copy so the caller's array is left untouched. The frame
        # is still RGB at this point, so red is (255, 0, 0).
        img_error = img.copy()
        cv2.putText(img_error, "MODEL NOT LOADED", (20, 40), cv2.FONT_HERSHEY_COMPLEX,
                    1, (255, 0, 0), 2)
        return img_error

    # Gradio provides the image as an RGB numpy array.
    # OpenCV uses BGR, so we need to convert the color space. The conversion
    # also yields a fresh array, so the annotations below do not touch `img`.
    img_original = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # Prepare the image for the model
    img_resized = cv2.resize(img_original, (32, 32))
    img_processed = preProcessing(img_resized)

    # Reshape to a batch of one 32x32 single-channel image
    img_reshaped = img_processed.reshape(1, 32, 32, 1)

    # --- Make Predictions ---
    # Get the raw prediction probabilities for each class
    predictions = model.predict(img_reshaped)
    # Find the class index with the highest probability
    class_index = np.argmax(predictions)
    # Get the highest probability value
    prob_val = np.amax(predictions)

    # --- Display the Result on the Image ---
    # Colors below are BGR because img_original is in BGR order.
    if prob_val > threshold:
        prediction_text = f"Class: {class_index}"
        probability_text = f"Prob: {prob_val:.2f}"

        cv2.putText(img_original, prediction_text, (20, 40), cv2.FONT_HERSHEY_COMPLEX,
                    1, (0, 255, 0), 2)
        cv2.putText(img_original, probability_text, (20, 80), cv2.FONT_HERSHEY_COMPLEX,
                    1, (0, 255, 0), 2)
    else:
        # Probability did not exceed the threshold: say so in red
        cv2.putText(img_original, "No certain prediction", (20, 40), cv2.FONT_HERSHEY_COMPLEX,
                    1, (0, 0, 255), 2)

    # Convert the annotated BGR frame back to RGB for display in Gradio
    img_display = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)

    return img_display
105
+
106
+ # --- Create and Launch the Gradio Interface ---
107
# Input: the webcam, streamed continuously so every frame reaches `predict`
_webcam_input = gr.Image(sources="webcam", streaming=True, label="Live Webcam Feed")
# Output: the annotated frame returned by `predict`
_result_output = gr.Image(label="Result")

iface = gr.Interface(
    fn=predict,
    inputs=_webcam_input,
    outputs=_result_output,
    live=True,
    title="Live Webcam Image Classifier",
    description="This application uses your webcam for real-time image classification. The model's prediction will be overlaid on the video feed.",
)

# Start the web interface only when run as a script and a model is available
if __name__ == "__main__":
    if model is None:
        print("Cannot launch the application because the model failed to load.")
    else:
        iface.launch()