Spaces:
Sleeping
Sleeping
sync with remote
Browse files
app.py
CHANGED
@@ -2,6 +2,8 @@ import streamlit as st
|
|
2 |
from datasets import load_dataset
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
|
|
|
|
5 |
|
6 |
# Load dataset from Hugging Face
|
7 |
repo_id = "louiecerv/cats_dogs_dataset"
|
@@ -10,53 +12,50 @@ dataset = load_dataset(repo_id)
|
|
10 |
# Select the split (train set by default)
|
11 |
split = "train"
|
12 |
data = dataset[split]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def main():
|
15 |
# Streamlit app
|
16 |
st.title("Image Dataset Explorer")
|
17 |
-
st.subheader(f"Displaying
|
18 |
-
|
19 |
-
# Extract images and labels
|
20 |
-
images = np.array(data["image"][:25]) # First 25 images
|
21 |
-
labels = np.array(data["label"][:25])
|
22 |
-
label_names = dataset[split].features["label"].names # Get class names
|
23 |
|
24 |
-
#
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
ax.set_title(label_names[labels[i]])
|
31 |
-
ax.axis("off")
|
32 |
-
|
33 |
-
st.pyplot(fig)
|
34 |
|
35 |
# Additional Exploration Options
|
36 |
st.sidebar.title("Explore the Dataset")
|
37 |
|
38 |
# Random Image Viewer
|
39 |
if st.sidebar.button("Show Random Images"):
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
fig.subplots_adjust(hspace=0.5)
|
46 |
-
|
47 |
-
for i, ax in enumerate(axes.flat):
|
48 |
-
ax.imshow(rand_images[i])
|
49 |
-
ax.set_title(label_names[rand_labels[i]])
|
50 |
-
ax.axis("off")
|
51 |
-
|
52 |
-
st.pyplot(fig)
|
53 |
|
54 |
# Class Distribution
|
55 |
if st.sidebar.button("Show Class Distribution"):
|
56 |
-
|
57 |
-
import seaborn as sns
|
58 |
-
|
59 |
-
label_counts = pd.Series(labels).value_counts().sort_index()
|
60 |
label_names_map = {i: name for i, name in enumerate(label_names)}
|
61 |
|
62 |
fig, ax = plt.subplots(figsize=(8, 4))
|
@@ -64,27 +63,20 @@ def main():
|
|
64 |
ax.set_title("Class Distribution")
|
65 |
ax.set_ylabel("Count")
|
66 |
ax.set_xlabel("Class")
|
67 |
-
|
68 |
st.pyplot(fig)
|
69 |
|
70 |
# Filter by class label
|
71 |
selected_label = st.sidebar.selectbox("Filter by Label", label_names)
|
72 |
if st.sidebar.button("Show Filtered Images"):
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
for i, ax in enumerate(axes.flat):
|
80 |
-
if i < len(filtered_images):
|
81 |
-
ax.imshow(filtered_images[i])
|
82 |
-
ax.set_title(selected_label)
|
83 |
-
ax.axis("off")
|
84 |
else:
|
85 |
-
|
86 |
|
87 |
-
st.pyplot(fig)
|
88 |
|
89 |
if __name__ == "__main__":
|
90 |
-
main()
|
|
|
2 |
from datasets import load_dataset
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
5 |
+
import pandas as pd
|
6 |
+
import seaborn as sns
|
7 |
|
8 |
# Load dataset from Hugging Face
|
9 |
repo_id = "louiecerv/cats_dogs_dataset"
|
|
|
12 |
# Select the split (train set by default)
|
13 |
split = "train"
|
14 |
data = dataset[split]
|
15 |
+
label_names = dataset[split].features["label"].names # Get class names upfront
|
16 |
+
|
17 |
+
def display_images(images, labels, label_names, cols=5):
|
18 |
+
"""Helper function to display images in a grid."""
|
19 |
+
rows = (len(images) + cols - 1) // cols # Calculate rows dynamically
|
20 |
+
fig, axes = plt.subplots(rows, cols, figsize=(2*cols, 2*rows)) #Adjust figure size dynamically
|
21 |
+
axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes] #Handle single plot case
|
22 |
+
for i, ax in enumerate(axes):
|
23 |
+
if i < len(images):
|
24 |
+
ax.imshow(images[i])
|
25 |
+
ax.set_title(label_names[labels[i]])
|
26 |
+
ax.axis("off")
|
27 |
+
else:
|
28 |
+
ax.axis("off") # Turn off unused subplots
|
29 |
+
|
30 |
+
st.pyplot(fig)
|
31 |
+
|
32 |
|
33 |
def main():
|
34 |
# Streamlit app
|
35 |
st.title("Image Dataset Explorer")
|
36 |
+
st.subheader(f"Displaying images from the {split} set")
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
+
#Initial Images
|
39 |
+
if st.button("Show First 25 Images"):
|
40 |
+
with st.spinner("Loading images..."):
|
41 |
+
images = np.array(data["image"][:25])
|
42 |
+
labels = np.array(data["label"][:25])
|
43 |
+
display_images(images, labels, label_names)
|
|
|
|
|
|
|
|
|
44 |
|
45 |
# Additional Exploration Options
|
46 |
st.sidebar.title("Explore the Dataset")
|
47 |
|
48 |
# Random Image Viewer
|
49 |
if st.sidebar.button("Show Random Images"):
|
50 |
+
with st.spinner("Loading images..."):
|
51 |
+
rand_indices = np.random.choice(len(data), 25, replace=False)
|
52 |
+
rand_images = np.array(data["image"])[rand_indices]
|
53 |
+
rand_labels = np.array(data["label"])[rand_indices]
|
54 |
+
display_images(rand_images, rand_labels, label_names)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
# Class Distribution
|
57 |
if st.sidebar.button("Show Class Distribution"):
|
58 |
+
label_counts = pd.Series(data["label"]).value_counts().sort_index() #Use entire dataset for distribution
|
|
|
|
|
|
|
59 |
label_names_map = {i: name for i, name in enumerate(label_names)}
|
60 |
|
61 |
fig, ax = plt.subplots(figsize=(8, 4))
|
|
|
63 |
ax.set_title("Class Distribution")
|
64 |
ax.set_ylabel("Count")
|
65 |
ax.set_xlabel("Class")
|
|
|
66 |
st.pyplot(fig)
|
67 |
|
68 |
# Filter by class label
|
69 |
selected_label = st.sidebar.selectbox("Filter by Label", label_names)
|
70 |
if st.sidebar.button("Show Filtered Images"):
|
71 |
+
with st.spinner("Loading images..."):
|
72 |
+
filtered_indices = [i for i, label in enumerate(data["label"]) if label_names[label] == selected_label]
|
73 |
+
filtered_images = np.array(data["image"])[filtered_indices] #No need to slice if displaying all found
|
74 |
+
filtered_labels = np.array(data["label"])[filtered_indices]
|
75 |
+
if len(filtered_images) > 0: #check for no images found
|
76 |
+
display_images(filtered_images, filtered_labels, label_names)
|
|
|
|
|
|
|
|
|
|
|
77 |
else:
|
78 |
+
st.write(f"No images found for label: {selected_label}")
|
79 |
|
|
|
80 |
|
81 |
if __name__ == "__main__":
|
82 |
+
main()
|