Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import pickle
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
# --- Parameters ---
|
8 |
+
# Threshold for minimum probability to classify an image
|
9 |
+
threshold = 0.65
|
10 |
+
|
11 |
+
# --- Load the Trained Model ---
|
12 |
+
# This block loads the pre-trained model from a pickle file.
|
13 |
+
# Ensure 'model_trained.p' is in the same directory as this script.
|
14 |
+
try:
|
15 |
+
with open("model_trained.p", "rb") as pickle_in:
|
16 |
+
model = pickle.load(pickle_in)
|
17 |
+
print("Model loaded successfully.")
|
18 |
+
except FileNotFoundError:
|
19 |
+
print("Error: 'model_trained.p' not found. Please ensure the model file is in the correct directory.")
|
20 |
+
model = None
|
21 |
+
except Exception as e:
|
22 |
+
print(f"An error occurred while loading the model: {e}")
|
23 |
+
model = None
|
24 |
+
|
25 |
+
# --- Preprocessing Function ---
|
26 |
+
def preProcessing(img):
|
27 |
+
"""
|
28 |
+
Converts an image to grayscale, applies histogram equalization,
|
29 |
+
and normalizes the pixel values.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
img (numpy.ndarray): The input image in BGR format.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
numpy.ndarray: The preprocessed image.
|
36 |
+
"""
|
37 |
+
# Convert image to grayscale
|
38 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
39 |
+
# Equalize the histogram of the grayscale image to improve contrast
|
40 |
+
img = cv2.equalizeHist(img)
|
41 |
+
# Normalize pixel values to be between 0 and 1
|
42 |
+
img = img / 255.0
|
43 |
+
return img
|
44 |
+
|
45 |
+
# --- Prediction Function for Live Feed ---
|
46 |
+
def predict(img):
|
47 |
+
"""
|
48 |
+
Takes a single frame from the webcam feed, preprocesses it,
|
49 |
+
and predicts the class using the loaded model. It then annotates
|
50 |
+
the frame with the prediction and probability.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
img (numpy.ndarray): The input frame from the Gradio webcam component (in RGB format).
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
numpy.ndarray: The frame annotated with the prediction details (in RGB format).
|
57 |
+
"""
|
58 |
+
if model is None:
|
59 |
+
# If the model isn't loaded, return the frame without any text
|
60 |
+
# and add an error message.
|
61 |
+
cv2.putText(img, "MODEL NOT LOADED", (20, 40), cv2.FONT_HERSHEY_COMPLEX,
|
62 |
+
1, (0, 0, 255), 2)
|
63 |
+
return img
|
64 |
+
|
65 |
+
# Gradio provides the image as an RGB numpy array.
|
66 |
+
# OpenCV uses BGR, so we need to convert the color space.
|
67 |
+
img_original = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
68 |
+
|
69 |
+
# Prepare the image for the model
|
70 |
+
img_resized = cv2.resize(img_original, (32, 32))
|
71 |
+
img_processed = preProcessing(img_resized)
|
72 |
+
|
73 |
+
# Reshape the image to match the model's expected input shape
|
74 |
+
img_reshaped = img_processed.reshape(1, 32, 32, 1)
|
75 |
+
|
76 |
+
# --- Make Predictions ---
|
77 |
+
# Get the raw prediction probabilities for each class
|
78 |
+
predictions = model.predict(img_reshaped)
|
79 |
+
# Find the class index with the highest probability
|
80 |
+
class_index = np.argmax(predictions)
|
81 |
+
# Get the highest probability value
|
82 |
+
prob_val = np.amax(predictions)
|
83 |
+
|
84 |
+
# --- Display the Result on the Image ---
|
85 |
+
# If the probability is higher than the set threshold, annotate the image
|
86 |
+
if prob_val > threshold:
|
87 |
+
# Prepare the text to be displayed
|
88 |
+
prediction_text = f"Class: {class_index}"
|
89 |
+
probability_text = f"Prob: {prob_val:.2f}"
|
90 |
+
|
91 |
+
# Add the text to the original image frame
|
92 |
+
cv2.putText(img_original, prediction_text, (20, 40), cv2.FONT_HERSHEY_COMPLEX,
|
93 |
+
1, (0, 255, 0), 2)
|
94 |
+
cv2.putText(img_original, probability_text, (20, 80), cv2.FONT_HERSHEY_COMPLEX,
|
95 |
+
1, (0, 255, 0), 2)
|
96 |
+
else:
|
97 |
+
# If probability is below threshold, indicate that
|
98 |
+
cv2.putText(img_original, "No certain prediction", (20, 40), cv2.FONT_HERSHEY_COMPLEX,
|
99 |
+
1, (0, 0, 255), 2)
|
100 |
+
|
101 |
+
# Convert the annotated BGR frame back to RGB for display in Gradio
|
102 |
+
img_display = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
|
103 |
+
|
104 |
+
return img_display
|
105 |
+
|
106 |
+
# --- Create and Launch the Gradio Interface ---
|
107 |
+
iface = gr.Interface(
|
108 |
+
fn=predict,
|
109 |
+
inputs=gr.Image(sources="webcam", streaming=True, label="Live Webcam Feed"),
|
110 |
+
outputs=gr.Image(label="Result"),
|
111 |
+
live=True,
|
112 |
+
title="Live Webcam Image Classifier",
|
113 |
+
description="This application uses your webcam for real-time image classification. The model's prediction will be overlaid on the video feed.",
|
114 |
+
)
|
115 |
+
|
116 |
+
# Launch the web interface
|
117 |
+
if __name__ == "__main__":
|
118 |
+
if model is not None:
|
119 |
+
iface.launch()
|
120 |
+
else:
|
121 |
+
print("Cannot launch the application because the model failed to load.")
|