Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,11 +7,18 @@ from keras.models import Model
|
|
7 |
from datasets import load_dataset
|
8 |
from sklearn.cluster import KMeans
|
9 |
import matplotlib.pyplot as plt
|
|
|
10 |
|
11 |
autoencoder = load_model("autoencoder_model.keras")
|
12 |
encoded_images = np.load("X_encoded_compressed.npy")
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
num_clusters = 10 # Choose the number of clusters
|
@@ -38,11 +45,6 @@ def find_nearest_neighbors(encoded_images, input_image, top_n=5):
|
|
38 |
return [(index, distances[index]) for index in nearest_neighbors]
|
39 |
|
40 |
def get_image(index):
|
41 |
-
dataset = load_dataset('eybro/images')
|
42 |
-
split_dataset = dataset['train'].train_test_split(test_size=0.2)
|
43 |
-
dataset['train'] = split_dataset['train']
|
44 |
-
dataset['test'] = split_dataset['test']
|
45 |
-
split = len(dataset["train"])
|
46 |
if index < split:
|
47 |
return dataset["train"][index]
|
48 |
else:
|
@@ -80,7 +82,6 @@ def inference(image):
|
|
80 |
print(f"top 4: {top4}")
|
81 |
|
82 |
for i in top4:
|
83 |
-
print(i)
|
84 |
im = get_image(i)
|
85 |
print(im["label"], im["timestamp"])
|
86 |
|
|
|
7 |
from datasets import load_dataset
|
8 |
from sklearn.cluster import KMeans
|
9 |
import matplotlib.pyplot as plt
|
10 |
+
import pickle
|
11 |
|
12 |
autoencoder = load_model("autoencoder_model.keras")
|
13 |
encoded_images = np.load("X_encoded_compressed.npy")
|
14 |
+
|
15 |
+
|
16 |
+
# Load the dataset with the split index
|
17 |
+
with open("dataset_with_split.pkl", "rb") as f:
|
18 |
+
data = pickle.load(f)
|
19 |
+
|
20 |
+
dataset = data['dataset']
|
21 |
+
split = data['split']
|
22 |
|
23 |
|
24 |
num_clusters = 10 # Choose the number of clusters
|
|
|
45 |
return [(index, distances[index]) for index in nearest_neighbors]
|
46 |
|
47 |
def get_image(index):
|
|
|
|
|
|
|
|
|
|
|
48 |
if index < split:
|
49 |
return dataset["train"][index]
|
50 |
else:
|
|
|
82 |
print(f"top 4: {top4}")
|
83 |
|
84 |
for i in top4:
|
|
|
85 |
im = get_image(i)
|
86 |
print(im["label"], im["timestamp"])
|
87 |
|