eybro commited on
Commit
d4c93a0
·
verified ·
1 Parent(s): 7e60582

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -7,11 +7,18 @@ from keras.models import Model
7
  from datasets import load_dataset
8
  from sklearn.cluster import KMeans
9
  import matplotlib.pyplot as plt
 
10
 
11
  autoencoder = load_model("autoencoder_model.keras")
12
  encoded_images = np.load("X_encoded_compressed.npy")
13
- print("Shape of encoded_images:", encoded_images.shape)
14
- print("Sample encoded image:", encoded_images[0])
 
 
 
 
 
 
15
 
16
 
17
  num_clusters = 10 # Choose the number of clusters
@@ -38,11 +45,6 @@ def find_nearest_neighbors(encoded_images, input_image, top_n=5):
38
  return [(index, distances[index]) for index in nearest_neighbors]
39
 
40
  def get_image(index):
41
- dataset = load_dataset('eybro/images')
42
- split_dataset = dataset['train'].train_test_split(test_size=0.2)
43
- dataset['train'] = split_dataset['train']
44
- dataset['test'] = split_dataset['test']
45
- split = len(dataset["train"])
46
  if index < split:
47
  return dataset["train"][index]
48
  else:
@@ -80,7 +82,6 @@ def inference(image):
80
  print(f"top 4: {top4}")
81
 
82
  for i in top4:
83
- print(i)
84
  im = get_image(i)
85
  print(im["label"], im["timestamp"])
86
 
 
7
  from datasets import load_dataset
8
  from sklearn.cluster import KMeans
9
  import matplotlib.pyplot as plt
10
+ import pickle
11
 
12
  autoencoder = load_model("autoencoder_model.keras")
13
  encoded_images = np.load("X_encoded_compressed.npy")
14
+
15
+
16
+ # Load the dataset with the split index
17
+ with open("dataset_with_split.pkl", "rb") as f:
18
+ data = pickle.load(f)
19
+
20
+ dataset = data['dataset']
21
+ split = data['split']
22
 
23
 
24
  num_clusters = 10 # Choose the number of clusters
 
45
  return [(index, distances[index]) for index in nearest_neighbors]
46
 
47
  def get_image(index):
 
 
 
 
 
48
  if index < split:
49
  return dataset["train"][index]
50
  else:
 
82
  print(f"top 4: {top4}")
83
 
84
  for i in top4:
 
85
  im = get_image(i)
86
  print(im["label"], im["timestamp"])
87