louiecerv committed on
Commit
3761521
·
1 Parent(s): 42a99a5

updated the dataset handling

Browse files
Files changed (1) hide show
  1. app.py +33 -31
app.py CHANGED
@@ -5,59 +5,62 @@ import matplotlib.pyplot as plt
5
  import pandas as pd
6
  import seaborn as sns
7
 
8
- # Load dataset from Hugging Face
9
- repo_id = "louiecerv/cats_dogs_dataset"
10
- dataset = load_dataset(repo_id)
 
11
 
12
- # Select the split (train set by default)
13
  split = "train"
14
  data = dataset[split]
15
- label_names = dataset[split].features["label"].names # Get class names upfront
 
 
 
 
16
 
17
  def display_images(images, labels, label_names, cols=5):
18
- """Helper function to display images in a grid."""
19
- rows = (len(images) + cols - 1) // cols # Calculate rows dynamically
20
- fig, axes = plt.subplots(rows, cols, figsize=(2*cols, 2*rows)) #Adjust figure size dynamically
21
- axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes] #Handle single plot case
 
22
  for i, ax in enumerate(axes):
23
  if i < len(images):
24
  ax.imshow(images[i])
25
  ax.set_title(label_names[labels[i]])
26
  ax.axis("off")
27
  else:
28
- ax.axis("off") # Turn off unused subplots
29
-
30
  st.pyplot(fig)
31
 
32
-
33
  def main():
34
- # Streamlit app
35
  st.title("Image Dataset Explorer")
36
  st.subheader(f"Displaying images from the {split} set")
37
-
38
- #Initial Images
39
  if st.button("Show First 25 Images"):
40
  with st.spinner("Loading images..."):
41
- images = np.array(data["image"][:25])
42
- labels = np.array(data["label"][:25])
43
  display_images(images, labels, label_names)
44
 
45
- # Additional Exploration Options
46
  st.sidebar.title("Explore the Dataset")
47
 
48
  # Random Image Viewer
49
  if st.sidebar.button("Show Random Images"):
50
  with st.spinner("Loading images..."):
51
- rand_indices = np.random.choice(len(data), 25, replace=False)
52
- rand_images = np.array(data["image"])[rand_indices]
53
- rand_labels = np.array(data["label"])[rand_indices]
54
- display_images(rand_images, rand_labels, label_names)
55
 
56
  # Class Distribution
57
  if st.sidebar.button("Show Class Distribution"):
58
- label_counts = pd.Series(data["label"]).value_counts().sort_index() #Use entire dataset for distribution
59
  label_names_map = {i: name for i, name in enumerate(label_names)}
60
-
61
  fig, ax = plt.subplots(figsize=(8, 4))
62
  sns.barplot(x=[label_names_map[i] for i in label_counts.index], y=label_counts.values, ax=ax)
63
  ax.set_title("Class Distribution")
@@ -69,14 +72,13 @@ def main():
69
  selected_label = st.sidebar.selectbox("Filter by Label", label_names)
70
  if st.sidebar.button("Show Filtered Images"):
71
  with st.spinner("Loading images..."):
72
- filtered_indices = [i for i, label in enumerate(data["label"]) if label_names[label] == selected_label]
73
- filtered_images = np.array(data["image"])[filtered_indices] #No need to slice if displaying all found
74
- filtered_labels = np.array(data["label"])[filtered_indices]
75
- if len(filtered_images) > 0: #check for no images found
76
- display_images(filtered_images, filtered_labels, label_names)
77
  else:
78
  st.write(f"No images found for label: {selected_label}")
79
 
80
-
81
  if __name__ == "__main__":
82
- main()
 
5
  import pandas as pd
6
  import seaborn as sns
7
 
8
+ # Load dataset from Hugging Face (note: no st.cache_* decorator is applied to this function, so Streamlit does not cache it)
9
+ def load_hf_dataset():
10
+ repo_id = "louiecerv/cats_dogs_dataset"
11
+ return load_dataset(repo_id)
12
 
13
+ dataset = load_hf_dataset()
14
  split = "train"
15
  data = dataset[split]
16
+ label_names = dataset[split].features["label"].names
17
+
18
+ @st.cache_data
19
+ def get_label_counts():
20
+ return pd.Series(data["label"]).value_counts().sort_index()
21
 
22
  def display_images(images, labels, label_names, cols=5):
23
+ """Display images in a grid."""
24
+ rows = (len(images) + cols - 1) // cols
25
+ fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows))
26
+ axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes]
27
+
28
  for i, ax in enumerate(axes):
29
  if i < len(images):
30
  ax.imshow(images[i])
31
  ax.set_title(label_names[labels[i]])
32
  ax.axis("off")
33
  else:
34
+ ax.axis("off")
35
+
36
  st.pyplot(fig)
37
 
 
38
  def main():
 
39
  st.title("Image Dataset Explorer")
40
  st.subheader(f"Displaying images from the {split} set")
41
+
42
+ # Show Initial Images
43
  if st.button("Show First 25 Images"):
44
  with st.spinner("Loading images..."):
45
+ images = [data[i]["image"] for i in range(25)]
46
+ labels = [data[i]["label"] for i in range(25)]
47
  display_images(images, labels, label_names)
48
 
 
49
  st.sidebar.title("Explore the Dataset")
50
 
51
  # Random Image Viewer
52
  if st.sidebar.button("Show Random Images"):
53
  with st.spinner("Loading images..."):
54
+ rand_indices = [int(i) for i in np.random.choice(len(data), 25, replace=False)]
55
+ images = [data[i]["image"] for i in rand_indices]
56
+ labels = [data[i]["label"] for i in rand_indices]
57
+ display_images(images, labels, label_names)
58
 
59
  # Class Distribution
60
  if st.sidebar.button("Show Class Distribution"):
61
+ label_counts = get_label_counts()
62
  label_names_map = {i: name for i, name in enumerate(label_names)}
63
+
64
  fig, ax = plt.subplots(figsize=(8, 4))
65
  sns.barplot(x=[label_names_map[i] for i in label_counts.index], y=label_counts.values, ax=ax)
66
  ax.set_title("Class Distribution")
 
72
  selected_label = st.sidebar.selectbox("Filter by Label", label_names)
73
  if st.sidebar.button("Show Filtered Images"):
74
  with st.spinner("Loading images..."):
75
+ filtered_indices = [int(i) for i in np.where(np.array(data["label"]) == label_names.index(selected_label))[0]]
76
+ if len(filtered_indices) > 0:
77
+ images = [data[i]["image"] for i in filtered_indices]
78
+ labels = [data[i]["label"] for i in filtered_indices]
79
+ display_images(images, labels, label_names)
80
  else:
81
  st.write(f"No images found for label: {selected_label}")
82
 
 
83
  if __name__ == "__main__":
84
+ main()