bgaspra commited on
Commit
9cb34ee
·
verified ·
1 Parent(s): 3358a37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -34
app.py CHANGED
@@ -8,13 +8,14 @@ from tensorflow.keras.preprocessing import image
8
  from sklearn.neighbors import NearestNeighbors
9
  import joblib
10
  from PIL import UnidentifiedImageError, Image
 
11
  import gradio as gr
12
 
13
  # Load the dataset
14
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
15
 
16
  # Take a subset of the dataset
17
- subset_size = 50
18
  dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
19
 
20
  # Directory to save images
@@ -86,7 +87,10 @@ image_paths = np.load('image_paths.npy', allow_pickle=True)
86
  model_names = np.load('model_names.npy', allow_pickle=True)
87
 
88
  # Function to get recommendations
89
- def get_recommendations(img_path, model, nbrs, image_paths, model_names, n_neighbors=5):
 
 
 
90
  img_features = extract_features(img_path, model)
91
  distances, indices = nbrs.kneighbors([img_features])
92
 
@@ -94,40 +98,20 @@ def get_recommendations(img_path, model, nbrs, image_paths, model_names, n_neigh
94
  recommended_model_names = [model_names[idx] for idx in indices.flatten()]
95
  recommended_distances = distances.flatten()
96
 
97
- return recommended_images, recommended_model_names, recommended_distances
98
 
99
- def get_recommendations_and_display(img_path):
100
- recommended_images, recommended_model_names, recommended_distances = get_recommendations(img_path, model, nbrs, image_paths, model_names)
101
-
102
- results = []
103
- for i in range(len(recommended_images)):
104
- result = {
105
- "Image": Image.open(recommended_images[i]),
106
- "Model Name": recommended_model_names[i],
107
- "Distance": recommended_distances[i]
108
- }
109
- results.append(result)
110
- return results
111
-
112
- # Define Gradio interface
113
- def gradio_interface(input_image):
114
- input_image.save("input_image.jpg") # Save the input image
115
- recommendations = get_recommendations_and_display("input_image.jpg")
116
- outputs = []
117
- for i, rec in enumerate(recommendations):
118
- outputs.append((rec["Image"], f"{rec['Model Name']} (Distance: {rec['Distance']:.2f})"))
119
- return outputs
120
-
121
- # Create the Gradio app
122
- iface = gr.Interface(
123
- fn=gradio_interface,
124
- inputs=gr.Image(type="pil"),
125
- outputs=[gr.Image(label=f"Recommendation {i+1} Image") for i in range(5)] +
126
- [gr.Textbox(label=f"Recommendation {i+1} Details") for i in range(5)],
127
  title="Image Recommendation System",
128
- description="Upload an image to get recommendations based on the image"
129
  )
130
 
131
- # Launch the Gradio app
132
  if __name__ == "__main__":
133
- iface.launch()
 
8
  from sklearn.neighbors import NearestNeighbors
9
  import joblib
10
  from PIL import UnidentifiedImageError, Image
11
+ import matplotlib.pyplot as plt
12
  import gradio as gr
13
 
14
  # Load the dataset
15
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
16
 
17
  # Take a subset of the dataset
18
+ subset_size = 100 # Reduce the subset size for faster execution in Spaces
19
  dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
20
 
21
  # Directory to save images
 
87
  model_names = np.load('model_names.npy', allow_pickle=True)
88
 
89
  # Function to get recommendations
90
+ def get_recommendations(img, n_neighbors=5):
91
+ img_path = "temp_input_image.png"
92
+ img.save(img_path)
93
+
94
  img_features = extract_features(img_path, model)
95
  distances, indices = nbrs.kneighbors([img_features])
96
 
 
98
  recommended_model_names = [model_names[idx] for idx in indices.flatten()]
99
  recommended_distances = distances.flatten()
100
 
101
+ return [Image.open(img_path) for img_path in recommended_images], recommended_model_names, recommended_distances
102
 
103
+ # Gradio interface
104
+ def display_images(input_image):
105
+ recommended_images, recommended_model_names, recommended_distances = get_recommendations(input_image)
106
+ return [(img, f'{name}, Distance: {dist:.2f}') for img, name, dist in zip(recommended_images, recommended_model_names, recommended_distances)]
107
+
108
+ interface = gr.Interface(
109
+ fn=display_images,
110
+ inputs=gr.inputs.Image(type="pil"),
111
+ outputs=gr.outputs.Carousel(label="Recommended Images", item_shape=(256, 256)),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  title="Image Recommendation System",
113
+ description="Upload an image and get similar images with their model names and distances."
114
  )
115
 
 
116
  if __name__ == "__main__":
117
+ interface.launch()