CineAI commited on
Commit
89f3202
·
verified ·
1 Parent(s): 608c7ca

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +127 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,129 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from ultralytics import YOLO
3
+ from PIL import Image
4
+ import io
5
 
6
+ # --- Page Configuration ---
7
+ st.set_page_config(
8
+ page_title="YOLO Object Detection",
9
+ page_icon="🤖",
10
+ layout="wide",
11
+ initial_sidebar_state="expanded",
12
+ )
13
+
14
+ # --- Custom CSS for Styling ---
15
+ st.markdown("""
16
+ <style>
17
+ /* Main app background */
18
+ .stApp {
19
+ background-color: #f0f2f6;
20
+ }
21
+ /* Sidebar styling */
22
+ .css-1d391kg {
23
+ background-color: #ffffff;
24
+ }
25
+ /* Title styling */
26
+ h1 {
27
+ color: #1E3A8A; /* A deep blue color */
28
+ text-align: center;
29
+ }
30
+ /* Header styling */
31
+ h2, h3 {
32
+ color: #3B82F6; /* A lighter blue */
33
+ }
34
+ /* Expander styling */
35
+ .st-expander {
36
+ border: 1px solid #ddd;
37
+ border-radius: 10px;
38
+ background-color: #ffffff;
39
+ }
40
+ </style>
41
+ """, unsafe_allow_html=True)
42
+
43
+ # --- Model Loading ---
44
+ # Use st.cache_resource to load the model only once
45
+ @st.cache_resource
46
+ def load_model(model_path):
47
+ """
48
+ Loads the YOLO model from the specified path.
49
+ Caches the model to avoid reloading on every interaction.
50
+ """
51
+ try:
52
+ model = YOLO(model_path)
53
+ return model
54
+ except Exception as e:
55
+ st.error(f"Error loading model: {e}")
56
+ return None
57
+
58
+ # Path to your custom model
59
+ MODEL_PATH = "rssi_last.pt"
60
+ model = load_model(MODEL_PATH)
61
+
62
+ # --- Sidebar ---
63
+ st.sidebar.header("Configuration")
64
+ confidence_threshold = st.sidebar.slider(
65
+ "Confidence Threshold", 0.0, 1.0, 0.4, 0.05
66
+ )
67
+ st.sidebar.markdown("---")
68
+ uploaded_file = st.sidebar.file_uploader(
69
+ "Upload an image...", type=["jpg", "jpeg", "png"]
70
+ )
71
+
72
+ st.sidebar.markdown("---")
73
+ st.sidebar.markdown(
74
+ "**About this App**\n\n"
75
+ "This application uses a custom-trained YOLO model to detect objects in images. "
76
+ "Upload an image and see the magic!"
77
+ )
78
+
79
+ # --- Main Page ---
80
+ st.title("🖼️ Custom Object Detection with YOLO")
81
+ st.write("Upload an image via the sidebar to see the model's predictions.")
82
+
83
+ if uploaded_file is not None:
84
+ # Read the uploaded image file
85
+ image_data = uploaded_file.getvalue()
86
+ original_image = Image.open(io.BytesIO(image_data))
87
+
88
+ # Create two columns for side-by-side display
89
+ col1, col2 = st.columns(2)
90
+
91
+ with col1:
92
+ st.subheader("Original Image")
93
+ st.image(original_image, caption="Your uploaded image.", use_column_width=True)
94
+
95
+ if model:
96
+ # Perform inference
97
+ with st.spinner("Running detection..."):
98
+ results = model(original_image, conf=confidence_threshold)
99
+
100
+ # The result object contains the annotated image and detection data
101
+ result = results[0]
102
+
103
+ # Use the plot() method to get an annotated image (in BGR format)
104
+ annotated_image_bgr = result.plot()
105
+ # Convert BGR to RGB for display in Streamlit
106
+ annotated_image_rgb = annotated_image_bgr[..., ::-1]
107
+
108
+ with col2:
109
+ st.subheader("Detected Objects")
110
+ st.image(annotated_image_rgb, caption="Image with detected objects.", use_column_width=True)
111
+
112
+ # Display detection details
113
+ st.subheader("Detection Details")
114
+ if len(result.boxes) > 0:
115
+ with st.expander("Click to see detailed results", expanded=True):
116
+ # Extract details for each detected box
117
+ for i, box in enumerate(result.boxes):
118
+ label = result.names[box.cls[0].item()]
119
+ conf = box.conf[0].item()
120
+ xywhn = box.xywhn[0].tolist() # Normalized xywh
121
+
122
+ st.markdown(f"**Object {i+1}: `{label}`**")
123
+ st.write(f"- Confidence: **{conf:.2f}**")
124
+ st.write(f"- Bounding Box (Normalized xywh):")
125
+ st.code(f" x: {xywhn[0]:.4f}, y: {xywhn[1]:.4f}, width: {xywhn[2]:.4f}, height: {xywhn[3]:.4f}")
126
+ else:
127
+ st.info("No objects were detected with the current confidence threshold.")
128
+ else:
129
+ st.info("Please upload an image using the sidebar to begin.")