File size: 2,220 Bytes
f9a8213
de4f74d
 
 
 
f9a8213
56cb512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de4f74d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56cb512
f9a8213
b93dddd
c253956
b93dddd
f9a8213
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
import numpy as np


# Pre-computed encodings of the dataset images, loaded once at import time
# and searched by find_nearest_neighbors. Presumably row i corresponds to the
# global index used by get_image (train split first, then test) -- confirm.
encoded_images = np.load(file="X_encoded_compressed.npy")

def find_nearest_neighbors(encoded_images, input_image, top_n=5):
    """
    Find the closest neighbors to the input image in the encoded image space.

    Distances are Euclidean and are computed directly with NumPy.

    Args:
    encoded_images (np.ndarray): Array of encoded images, shape
        (n_samples, ...). Each sample is flattened to a feature vector, so
        both (n_samples, n_features) and higher-rank encodings work.
    input_image (np.ndarray): The encoded input image. Any shape holding as
        many elements as one flattened sample works, e.g. (1, n_features) or
        (n_features,).
    top_n (int): The number of nearest neighbors to return. A value larger
        than n_samples returns every sample.

    Returns:
    List of tuples: (index, distance) of the top_n nearest neighbors, closest
    first. Equal distances keep ascending index order. Returns an empty list
    when encoded_images has no samples.

    Raises:
    ValueError: If input_image does not have the same number of features as
        one sample of encoded_images.
    """
    samples = np.asarray(encoded_images)
    if len(samples) == 0:
        return []
    # Flatten each sample so (n, features) and higher-rank encodings such as
    # (n, h, w, c) are compared as plain vectors.
    samples = samples.reshape(len(samples), -1)
    # Integer encodings (e.g. uint8) would wrap around on subtraction.
    if not np.issubdtype(samples.dtype, np.floating):
        samples = samples.astype(np.float64)

    query = np.asarray(input_image).reshape(1, -1)
    # Without this check a single-feature query would silently broadcast.
    if query.shape[1] != samples.shape[1]:
        raise ValueError(
            f"input_image has {query.shape[1]} features but encoded_images "
            f"has {samples.shape[1]} per sample"
        )

    # Euclidean distance from the query to every sample.
    distances = np.linalg.norm(samples - query, axis=1)

    # A stable sort makes the order of equal-distance matches deterministic.
    nearest_neighbors = np.argsort(distances, kind="stable")[:top_n]
    return [(index, distances[index]) for index in nearest_neighbors]

def get_image(index):
    """Return the record at a global *index* spanning the train then test split.

    Indices below ``len(dataset["train"])`` address the train split; larger
    indices continue into the test split, offset by the train length.

    NOTE(review): relies on a module-level ``dataset`` with "train" and "test"
    splits, which is not defined in this file -- confirm where it is loaded.
    """
    train_split = dataset["train"]
    n_train = len(train_split)
    if index >= n_train:
        return dataset["test"][index - n_train]
    return train_split[index]

def process_image(image):
    """Encode an uploaded image into the feature space of ``encoded_images``.

    Not implemented yet: the encoder that produced the stored encodings is not
    present in this file.

    Args:
        image: The raw value handed over by the Gradio input component.

    Returns:
        np.ndarray: Once implemented, the encoded image, reshapeable to a single
        row of ``encoded_images``.

    Raises:
        NotImplementedError: Always, until the encoding step is written.
    """
    # Fail loudly here instead of returning None, which would only surface
    # later as a confusing error inside the nearest-neighbor search.
    raise NotImplementedError(
        "process_image must encode the uploaded image the same way "
        "X_encoded_compressed.npy was produced"
    )

def inference(image):
    """Retrieve the dataset records closest to an uploaded image.

    The upload is encoded with ``process_image`` and then compared against the
    module-level ``encoded_images``. The four closest matches are looked up
    with ``get_image``.

    Args:
        image: Value handed over by the Gradio input component (a ``gr.File``
            upload in this app; its exact type depends on the Gradio version).

    Returns:
        str: One line per match (up to four) giving the global dataset index,
        the distance in encoded space, and the record's ``label`` and
        ``timestamp``, for display in a ``"text"`` output component.
    """
    input_image = process_image(image)

    nearest_neighbors = find_nearest_neighbors(encoded_images, input_image, top_n=5)

    # Log the raw (index, distance) pairs on the server console.
    print("Nearest neighbors (index, distance):")
    for neighbor in nearest_neighbors:
        print(neighbor)

    top4 = nearest_neighbors[:4]
    print(f"top 4: {[int(index) for index, _ in top4]}")

    summary = []
    for index, distance in top4:
        # Fetch each matching record once and reuse it for logging and output.
        record = get_image(int(index))
        print(record["label"], record["timestamp"])
        summary.append(
            f"#{int(index)} distance={float(distance):.4f} "
            f"label={record['label']} timestamp={record['timestamp']}"
        )

    # NOTE(review): no matplotlib preview grid is drawn here. `plt` is not
    # imported in this file, and a figure built inside this handler would
    # never reach the Gradio UI. Return the images to a gr.Gallery output if a
    # visual comparison is wanted.
    return "\n".join(summary)
    


# Web UI: upload a file, run nearest-neighbor retrieval, show the text result.
# `fn` must name a function defined in this file (`greet` does not exist).
demo = gr.Interface(
    fn=inference,
    inputs=gr.File(label="Upload image"),
    outputs="text",
)
demo.launch()