Spaces:
Sleeping
Sleeping
updated the dataset handling
Browse files
app.py
CHANGED
@@ -5,59 +5,62 @@ import matplotlib.pyplot as plt
|
|
5 |
import pandas as pd
|
6 |
import seaborn as sns
|
7 |
|
8 |
-
# Load dataset from Hugging Face
|
9 |
-
|
10 |
-
|
|
|
11 |
|
12 |
-
|
13 |
split = "train"
|
14 |
data = dataset[split]
|
15 |
-
label_names = dataset[split].features["label"].names
|
|
|
|
|
|
|
|
|
16 |
|
17 |
def display_images(images, labels, label_names, cols=5):
|
18 |
-
"""
|
19 |
-
rows = (len(images) + cols - 1) // cols
|
20 |
-
fig, axes = plt.subplots(rows, cols, figsize=(2*cols, 2*rows))
|
21 |
-
axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes]
|
|
|
22 |
for i, ax in enumerate(axes):
|
23 |
if i < len(images):
|
24 |
ax.imshow(images[i])
|
25 |
ax.set_title(label_names[labels[i]])
|
26 |
ax.axis("off")
|
27 |
else:
|
28 |
-
ax.axis("off")
|
29 |
-
|
30 |
st.pyplot(fig)
|
31 |
|
32 |
-
|
33 |
def main():
|
34 |
-
# Streamlit app
|
35 |
st.title("Image Dataset Explorer")
|
36 |
st.subheader(f"Displaying images from the {split} set")
|
37 |
-
|
38 |
-
#Initial Images
|
39 |
if st.button("Show First 25 Images"):
|
40 |
with st.spinner("Loading images..."):
|
41 |
-
images =
|
42 |
-
labels =
|
43 |
display_images(images, labels, label_names)
|
44 |
|
45 |
-
# Additional Exploration Options
|
46 |
st.sidebar.title("Explore the Dataset")
|
47 |
|
48 |
# Random Image Viewer
|
49 |
if st.sidebar.button("Show Random Images"):
|
50 |
with st.spinner("Loading images..."):
|
51 |
-
rand_indices = np.random.choice(len(data), 25, replace=False)
|
52 |
-
|
53 |
-
|
54 |
-
display_images(
|
55 |
|
56 |
# Class Distribution
|
57 |
if st.sidebar.button("Show Class Distribution"):
|
58 |
-
label_counts =
|
59 |
label_names_map = {i: name for i, name in enumerate(label_names)}
|
60 |
-
|
61 |
fig, ax = plt.subplots(figsize=(8, 4))
|
62 |
sns.barplot(x=[label_names_map[i] for i in label_counts.index], y=label_counts.values, ax=ax)
|
63 |
ax.set_title("Class Distribution")
|
@@ -69,14 +72,13 @@ def main():
|
|
69 |
selected_label = st.sidebar.selectbox("Filter by Label", label_names)
|
70 |
if st.sidebar.button("Show Filtered Images"):
|
71 |
with st.spinner("Loading images..."):
|
72 |
-
filtered_indices = [i for i
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
display_images(
|
77 |
else:
|
78 |
st.write(f"No images found for label: {selected_label}")
|
79 |
|
80 |
-
|
81 |
if __name__ == "__main__":
|
82 |
-
main()
|
|
|
5 |
import pandas as pd
|
6 |
import seaborn as sns
|
7 |
|
8 |
+
# Load dataset from Hugging Face with caching
@st.cache_resource
def load_hf_dataset():
    """Load the cats-vs-dogs dataset from the Hugging Face Hub.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the dataset is loaded
    once and the same object is handed back on later reruns.

    Returns:
        Whatever ``load_dataset`` returns for the repository; callers index
        it by split name.
    """
    repo_id = "louiecerv/cats_dogs_dataset"
    return load_dataset(repo_id)
|
12 |
|
13 |
+
# Module-level dataset handles shared by every view in the app.
split = "train"
dataset = load_hf_dataset()
data = dataset[split]
# Class names are read from the split's "label" feature.
label_names = data.features["label"].names
|
17 |
+
|
18 |
+
@st.cache_data
def get_label_counts():
    """Return the number of samples per label id, ordered by label id."""
    label_column = pd.Series(data["label"])
    counts = label_column.value_counts()
    return counts.sort_index()
|
21 |
|
22 |
def display_images(images, labels, label_names, cols=5):
    """Display images in a grid, one titled subplot per image.

    Args:
        images: Sequence of images, each passed to ``Axes.imshow``.
        labels: Sequence of label ids, parallel to ``images``; each id is
            used to index ``label_names``.
        label_names: Sequence mapping a label id to its display name.
        cols: Number of columns in the grid.

    Returns:
        None. The figure is rendered through ``st.pyplot``; nothing is
        rendered when ``images`` is empty.
    """
    # An empty selection would request a grid with zero rows, which
    # plt.subplots rejects, so there is nothing to draw.
    if len(images) == 0:
        return

    rows = (len(images) + cols - 1) // cols  # ceiling division
    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid,
    # so no special case is needed for a single subplot.
    fig, axes = plt.subplots(
        rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False
    )

    for i, ax in enumerate(axes.ravel()):
        if i < len(images):
            ax.imshow(images[i])
            ax.set_title(label_names[labels[i]])
        # Hide ticks and frame on used and unused cells alike.
        ax.axis("off")

    st.pyplot(fig)
    # pyplot keeps every figure it created alive until it is closed; close it
    # so figures do not accumulate across Streamlit reruns.
    plt.close(fig)
|
37 |
|
|
|
38 |
def main():
|
|
|
39 |
st.title("Image Dataset Explorer")
|
40 |
st.subheader(f"Displaying images from the {split} set")
|
41 |
+
|
42 |
+
# Show Initial Images
|
43 |
if st.button("Show First 25 Images"):
|
44 |
with st.spinner("Loading images..."):
|
45 |
+
images = [data[i]["image"] for i in range(25)]
|
46 |
+
labels = [data[i]["label"] for i in range(25)]
|
47 |
display_images(images, labels, label_names)
|
48 |
|
|
|
49 |
st.sidebar.title("Explore the Dataset")
|
50 |
|
51 |
# Random Image Viewer
|
52 |
if st.sidebar.button("Show Random Images"):
|
53 |
with st.spinner("Loading images..."):
|
54 |
+
rand_indices = [int(i) for i in np.random.choice(len(data), 25, replace=False)]
|
55 |
+
images = [data[i]["image"] for i in rand_indices]
|
56 |
+
labels = [data[i]["label"] for i in rand_indices]
|
57 |
+
display_images(images, labels, label_names)
|
58 |
|
59 |
# Class Distribution
|
60 |
if st.sidebar.button("Show Class Distribution"):
|
61 |
+
label_counts = get_label_counts()
|
62 |
label_names_map = {i: name for i, name in enumerate(label_names)}
|
63 |
+
|
64 |
fig, ax = plt.subplots(figsize=(8, 4))
|
65 |
sns.barplot(x=[label_names_map[i] for i in label_counts.index], y=label_counts.values, ax=ax)
|
66 |
ax.set_title("Class Distribution")
|
|
|
72 |
selected_label = st.sidebar.selectbox("Filter by Label", label_names)
|
73 |
if st.sidebar.button("Show Filtered Images"):
|
74 |
with st.spinner("Loading images..."):
|
75 |
+
filtered_indices = [int(i) for i in np.where(np.array(data["label"]) == label_names.index(selected_label))[0]]
|
76 |
+
if len(filtered_indices) > 0:
|
77 |
+
images = [data[i]["image"] for i in filtered_indices]
|
78 |
+
labels = [data[i]["label"] for i in filtered_indices]
|
79 |
+
display_images(images, labels, label_names)
|
80 |
else:
|
81 |
st.write(f"No images found for label: {selected_label}")
|
82 |
|
|
|
83 |
# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()
|