# louiecerv's picture
# updated the dataset handling
# 3761521
import streamlit as st
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Load dataset from Hugging Face with caching
@st.cache_resource
def load_hf_dataset():
    """Load the cats-vs-dogs dataset from the Hugging Face Hub.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the dataset is loaded
    once and the same object is reused by later reruns instead of calling
    ``load_dataset`` again each time.

    Returns:
        Whatever ``load_dataset`` returns for the repository; the caller
        indexes it by split name.
    """
    repo_id = "louiecerv/cats_dogs_dataset"
    return load_dataset(repo_id)
# Module-level state used by the UI code: the split being explored, its rows,
# and the class names of the "label" feature.
split = "train"
dataset = load_hf_dataset()
data = dataset[split]
label_names = data.features["label"].names
@st.cache_data
def get_label_counts():
    """Count how many rows carry each label id.

    Returns:
        A pandas Series indexed by label id (ascending) whose values are the
        number of rows with that label. The result is cached by Streamlit.
    """
    label_column = pd.Series(data["label"])
    per_label = label_column.value_counts()
    return per_label.sort_index()
def display_images(images, labels, label_names, cols=5):
    """Display images in a grid.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``.
        labels: Sequence of label ids, parallel to ``images``.
        label_names: Sequence mapping a label id to its display name.
        cols: Number of columns in the grid.
    """
    if len(images) == 0:
        # A grid with zero rows is rejected by plt.subplots(), so report the
        # empty input instead of raising.
        st.write("No images to display.")
        return

    rows = (len(images) + cols - 1) // cols  # ceiling division
    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid,
    # so no special-casing of the single-Axes return value is needed.
    fig, axes = plt.subplots(
        rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False
    )
    for i, ax in enumerate(axes.flat):
        if i < len(images):
            ax.imshow(images[i])
            ax.set_title(label_names[labels[i]])
        # Hide ticks and frame on every cell, including unused trailing ones.
        ax.axis("off")
    st.pyplot(fig)
    # pyplot keeps a reference to every figure it creates until it is closed;
    # release it so repeated button clicks do not accumulate figures.
    plt.close(fig)
def _show_rows(indices):
    """Render the dataset rows at ``indices`` as a labelled image grid."""
    # Index each row once and take both fields from it, rather than indexing
    # the dataset separately for the image and for the label.
    rows = [data[i] for i in indices]
    if not rows:
        st.write("No images to display.")
        return
    images = [row["image"] for row in rows]
    labels = [row["label"] for row in rows]
    display_images(images, labels, label_names)


def main():
    """Build the dataset explorer page.

    The main area offers a button that shows the first images of the split.
    The sidebar offers a random sample, a class-distribution bar chart, and a
    view of all images that carry a selected label.
    """
    st.title("Image Dataset Explorer")
    st.subheader(f"Displaying images from the {split} set")

    # Never ask for more rows than the split holds: indexing past the end
    # raises, and sampling without replacement needs size <= population.
    sample_size = min(25, len(data))

    # Show Initial Images
    if st.button("Show First 25 Images"):
        with st.spinner("Loading images..."):
            _show_rows(range(sample_size))

    st.sidebar.title("Explore the Dataset")

    # Random Image Viewer
    if st.sidebar.button("Show Random Images"):
        with st.spinner("Loading images..."):
            picks = np.random.choice(len(data), sample_size, replace=False)
            # Convert NumPy integers to plain ints before indexing the dataset.
            _show_rows([int(i) for i in picks])

    # Class Distribution
    if st.sidebar.button("Show Class Distribution"):
        label_counts = get_label_counts()
        label_names_map = dict(enumerate(label_names))
        fig, ax = plt.subplots(figsize=(8, 4))
        sns.barplot(
            x=[label_names_map[i] for i in label_counts.index],
            y=label_counts.values,
            ax=ax,
        )
        ax.set_title("Class Distribution")
        ax.set_ylabel("Count")
        ax.set_xlabel("Class")
        st.pyplot(fig)
        # Release the figure so repeated clicks do not accumulate figures.
        plt.close(fig)

    # Filter by class label
    selected_label = st.sidebar.selectbox("Filter by Label", label_names)
    if st.sidebar.button("Show Filtered Images"):
        with st.spinner("Loading images..."):
            target_id = label_names.index(selected_label)
            matches = np.where(np.array(data["label"]) == target_id)[0]
            if len(matches) > 0:
                # NOTE(review): this renders every matching row in one grid,
                # which may be slow for a large class - consider capping.
                _show_rows([int(i) for i in matches])
            else:
                st.write(f"No images found for label: {selected_label}")
# Build the page only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()