|
import streamlit as st |
|
from datasets import load_dataset |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
import seaborn as sns |
|
|
|
|
|
def load_hf_dataset():
    """Load the ``louiecerv/cats_dogs_dataset`` repository from the Hugging Face Hub.

    Returns:
        The object produced by ``datasets.load_dataset``; the rest of this
        module indexes it by split name (for example ``"train"``).
    """
    # NOTE(review): this call is not wrapped in a Streamlit cache, so it
    # presumably runs again whenever the script is re-executed -- confirm
    # whether ``st.cache_resource`` should be applied here.
    return load_dataset("louiecerv/cats_dogs_dataset")
|
|
|
# Module-level state read by the functions below: the split being explored,
# the rows of that split, and the display names of its class labels.
split = "train"
dataset = load_hf_dataset()
data = dataset[split]
label_names = data.features["label"].names
|
|
|
@st.cache_data
def get_label_counts():
    """Count how many rows of ``data`` carry each label id.

    Returns:
        A ``pandas.Series`` whose index holds the label ids in ascending
        order and whose values are the number of rows with that id.
    """
    label_series = pd.Series(data["label"])
    occurrences = label_series.value_counts()
    return occurrences.sort_index()
|
|
|
def display_images(images, labels, label_names, cols=5):
    """Render images as a grid of matplotlib subplots inside the Streamlit page.

    Args:
        images: Sequence of images, each passed to ``Axes.imshow``.
        labels: Sequence of label ids parallel to ``images``; every id is
            used as an index into ``label_names``.
        label_names: Sequence mapping a label id to the title shown above
            the corresponding image.
        cols: Number of grid columns. Must be at least 1.

    Raises:
        ValueError: If ``cols`` is smaller than 1.
    """
    if cols < 1:
        raise ValueError(f"cols must be at least 1, got {cols}")

    if len(images) == 0:
        # A zero-row grid is rejected by plt.subplots, so report the empty
        # input to the user instead of crashing.
        st.info("No images to display.")
        return

    rows = (len(images) + cols - 1) // cols  # ceiling division
    # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid,
    # so no special-casing of the return type is needed.
    fig, axes = plt.subplots(
        rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False
    )

    try:
        for i, ax in enumerate(axes.flat):
            if i < len(images):
                ax.imshow(images[i])
                ax.set_title(label_names[labels[i]])
            # Hide ticks and frame on used and unused cells alike.
            ax.axis("off")
        st.pyplot(fig)
    finally:
        # pyplot keeps a reference to every figure it creates; close it so
        # figures do not pile up across repeated renders.
        plt.close(fig)
|
|
|
def _load_samples(indices):
    """Return ``(images, labels)`` for the rows of ``data`` at ``indices``."""
    # Read each row once; indexing ``data[i]`` separately for the image and
    # for the label would fetch the same row twice.
    rows = [data[i] for i in indices]
    images = [row["image"] for row in rows]
    labels = [row["label"] for row in rows]
    return images, labels


def main():
    """Build the Streamlit page.

    The main area offers a button that shows the first images of the split.
    The sidebar offers a random sample, a bar chart of the class
    distribution, and a view filtered by a selected label.
    """
    sample_size = 25

    st.title("Image Dataset Explorer")
    st.subheader(f"Displaying images from the {split} set")

    if st.button("Show First 25 Images"):
        with st.spinner("Loading images..."):
            # Clamp so a split with fewer rows than the sample size does not
            # index past the end.
            first_indices = range(min(sample_size, len(data)))
            images, labels = _load_samples(first_indices)
            display_images(images, labels, label_names)

    st.sidebar.title("Explore the Dataset")

    if st.sidebar.button("Show Random Images"):
        with st.spinner("Loading images..."):
            # Sampling without replacement cannot draw more rows than exist.
            draw = min(sample_size, len(data))
            rand_indices = np.random.choice(len(data), draw, replace=False)
            # tolist() converts the NumPy integers to plain Python ints.
            images, labels = _load_samples(rand_indices.tolist())
            display_images(images, labels, label_names)

    if st.sidebar.button("Show Class Distribution"):
        label_counts = get_label_counts()

        fig, ax = plt.subplots(figsize=(8, 4))
        try:
            sns.barplot(
                x=[label_names[i] for i in label_counts.index],
                y=label_counts.values,
                ax=ax,
            )
            ax.set_title("Class Distribution")
            ax.set_ylabel("Count")
            ax.set_xlabel("Class")
            st.pyplot(fig)
        finally:
            # Release the figure so repeated clicks do not accumulate figures.
            plt.close(fig)

    selected_label = st.sidebar.selectbox("Filter by Label", label_names)
    if st.sidebar.button("Show Filtered Images"):
        with st.spinner("Loading images..."):
            target_id = label_names.index(selected_label)
            matches = np.where(np.array(data["label"]) == target_id)[0]
            filtered_indices = matches.tolist()
            if filtered_indices:
                # NOTE(review): every matching row is rendered; for a large
                # class this may be slow -- confirm whether a cap is wanted.
                images, labels = _load_samples(filtered_indices)
                display_images(images, labels, label_names)
            else:
                st.write(f"No images found for label: {selected_label}")
|
|
|
# Build the page only when this file is executed as the entry script.
if __name__ == "__main__":
    main()
|
|