# OptiDigit / Home.py
# Author: Sathwikchowdary — last commit d500bfb ("Update Home.py", verified), 2.95 kB
from pathlib import Path

import cv2
import numpy as np
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from tensorflow.keras.models import load_model
# App configuration: kept ahead of every other st.* call in this script.
st.set_page_config(
    page_title="DigitSketch - AI Digit Classifier",
    layout="centered",
)

# Dark theme with neon accents. The .digit-result and .canvas-title classes
# are used by the HTML snippets rendered later in this script.
_CUSTOM_CSS = """
<style>
.stApp {
background-color: #121212;
color: #f0f0f0;
}
h1 {
color: #00ffff;
text-align: center;
text-shadow: 1px 1px 8px #00ffff;
}
.digit-result {
text-align: center;
font-size: 2.5em;
font-weight: bold;
color: #ff4d4d;
text-shadow: 1px 1px 10px #ff4d4d;
margin-top: 20px;
}
.canvas-title {
text-align: center;
color: #80dfff;
font-size: 1.2em;
margin-bottom: 10px;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# App title and description.
# (Emoji / en dash restored: the previous literals were UTF-8 bytes that had
# been mis-decoded, so they rendered as garbage characters.)
st.title("🎨 DigitSketch: AI Handwritten Digit Classifier")
st.markdown(
    "Draw a digit between **0–9** below, then click **🔮 Predict** "
    "to see what the AI thinks it is!"
)

# Sidebar: drawing settings passed to st_canvas.
# Defaults are a white pen on a black background.
st.sidebar.header("🛠️ Drawing Controls")
drawing_mode = st.sidebar.selectbox(
    "Choose a drawing tool:",
    ("freedraw", "line", "rect", "circle", "transform"),
)
stroke_width = st.sidebar.slider("Pen thickness", 1, 25, 10)
stroke_color = st.sidebar.color_picker("Pen color", "#FFFFFF")
bg_color = st.sidebar.color_picker("Canvas background", "#000000")
realtime_update = st.sidebar.checkbox("Live update", True)
# Load the trained model. st.cache_resource keeps the loaded model across
# script reruns instead of reloading it on every interaction.
@st.cache_resource
def load_digit_model():
    """Load and return the Keras digit classifier ``digit_reco.keras``.

    The file is looked up next to this script first, so the app works no
    matter which directory ``streamlit run`` is launched from. If it is not
    there, fall back to the plain relative path (resolved against the current
    working directory), which is what earlier versions used.
    """
    bundled = Path(__file__).resolve().parent / "digit_reco.keras"
    return load_model(str(bundled) if bundled.is_file() else "digit_reco.keras")


model = load_digit_model()
# Canvas drawing area: 280x280 px (10x the 28x28 model input), configured
# from the sidebar controls. (Heading emoji restored from mis-decoded bytes.)
st.markdown(
    '<div class="canvas-title">✍️ Draw your digit below</div>',
    unsafe_allow_html=True,
)
canvas_result = st_canvas(
    fill_color="rgba(255, 255, 255, 0.0)",  # Transparent fill (alpha 0)
    stroke_width=stroke_width,
    stroke_color=stroke_color,
    background_color=bg_color,
    update_streamlit=realtime_update,
    height=280,
    width=280,
    drawing_mode=drawing_mode,
    key="canvas",
)
# Predict button: turn the drawing into a 28x28 model input and classify it.
# (Emoji in the labels restored from mis-decoded UTF-8 bytes.)
if st.button("🔮 Predict"):
    image_data = canvas_result.image_data
    json_data = canvas_result.json_data
    # A non-None image_data does not by itself mean the user drew anything
    # (an untouched canvas can still report pixel data), so also require at
    # least one drawn object in the canvas JSON before predicting.
    has_drawing = bool(json_data and json_data.get("objects"))

    if image_data is not None and has_drawing:
        st.image(image_data, caption="🖼️ Your Drawing", use_container_width=True)

        # Image preprocessing: RGBA -> single-channel grayscale (alpha dropped).
        gray = cv2.cvtColor(image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
        # NOTE(review): with the sidebar defaults (white pen on black
        # background), and assuming image_data includes the canvas background,
        # the grayscale image is already white-on-black; this inversion turns
        # it into a black digit on white. That contradicts the old comment
        # ("Invert for white digit on black"). Left unchanged because the right
        # polarity depends on how digit_reco.keras was trained. Confirm against
        # the training pipeline; MNIST-style models expect white-on-black.
        # It also ignores user-chosen pen/background colours.
        gray = 255 - gray
        # Shrink the 280x280 canvas to 28x28. INTER_AREA averages each 10x10
        # source patch; the default bilinear interpolation samples only a few
        # source pixels per output pixel and can drop thin strokes entirely.
        small = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA)
        # Scale uint8 pixels to [0, 1] and add a batch axis: shape (1, 28, 28).
        batch = (small / 255.0).reshape((1, 28, 28))

        # Predict: the index of the highest output score is reported as the digit.
        prediction = model.predict(batch)
        predicted_digit = int(np.argmax(prediction))

        st.markdown(
            f'<div class="digit-result">Predicted Digit: {predicted_digit}</div>',
            unsafe_allow_html=True,
        )
    else:
        st.warning("⚠️ Please draw something before clicking Predict.")