bgaspra commited on
Commit
61e280c
·
verified ·
1 Parent(s): 2d4593d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -40
app.py CHANGED
@@ -3,14 +3,13 @@ import requests
3
  from tqdm import tqdm
4
  from datasets import load_dataset
5
  import numpy as np
6
- import tensorflow as tf
7
  from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
8
  from tensorflow.keras.preprocessing import image
9
  from sklearn.neighbors import NearestNeighbors
10
  import joblib
11
  from PIL import UnidentifiedImageError, Image
12
- import gradio as gr
13
  import matplotlib.pyplot as plt
 
14
 
15
  # Load the dataset
16
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
@@ -23,23 +22,8 @@ dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
23
  image_dir = 'civitai_images'
24
  os.makedirs(image_dir, exist_ok=True)
25
 
26
- # Try to use GPU, fall back to CPU if not available
27
- try:
28
- gpus = tf.config.list_physical_devices('GPU')
29
- if gpus:
30
- tf.config.experimental.set_memory_growth(gpus[0], True)
31
- device = '/GPU:0'
32
- print("Using GPU")
33
- else:
34
- raise RuntimeError("No GPU found")
35
- except RuntimeError as e:
36
- print(e)
37
- device = '/CPU:0'
38
- print("Using CPU")
39
-
40
  # Load the ResNet50 model pretrained on ImageNet
41
- with tf.device(device):
42
- model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
43
 
44
  # Function to extract features
45
  def extract_features(img_path, model):
@@ -47,8 +31,7 @@ def extract_features(img_path, model):
47
  img_array = image.img_to_array(img)
48
  img_array = np.expand_dims(img_array, axis=0)
49
  img_array = preprocess_input(img_array)
50
- with tf.device(device):
51
- features = model.predict(img_array)
52
  return features.flatten()
53
 
54
  # Extract features for a sample of images
@@ -114,19 +97,6 @@ def get_recommendations(img_path, model, nbrs, image_paths, model_names, n_neigh
114
 
115
  return recommended_images, recommended_model_names, recommended_distances
116
 
117
- def recommend(image):
118
- # Save uploaded image to a path
119
- image_path = "uploaded_image.jpg"
120
- image.save(image_path)
121
-
122
- recommended_images, recommended_model_names, recommended_distances = get_recommendations(image_path, model, nbrs, image_paths, model_names)
123
- result = list(zip(recommended_images, recommended_model_names, recommended_distances))
124
-
125
- # Display images with matplotlib
126
- display_images(recommended_images, recommended_model_names, recommended_distances)
127
-
128
- return result
129
-
130
  def display_images(image_paths, model_names, distances):
131
  plt.figure(figsize=(20, 10))
132
  for i, (img_path, model_name, distance) in enumerate(zip(image_paths, model_names, distances)):
@@ -137,12 +107,19 @@ def display_images(image_paths, model_names, distances):
137
  plt.axis('off')
138
  plt.show()
139
 
140
- interface = gr.Interface(
141
- fn=recommend,
142
- inputs=gr.Image(type="pil"), # Updated input component
143
- outputs=gr.Text(), # Updated output component
 
 
 
 
 
 
 
 
144
  title="Image Recommendation System",
145
- description="Upload an image and get 5 recommended similar images with model names and distances."
146
  )
147
-
148
- interface.launch()
 
3
  from tqdm import tqdm
4
  from datasets import load_dataset
5
  import numpy as np
 
6
  from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
7
  from tensorflow.keras.preprocessing import image
8
  from sklearn.neighbors import NearestNeighbors
9
  import joblib
10
  from PIL import UnidentifiedImageError, Image
 
11
  import matplotlib.pyplot as plt
12
+ import gradio as gr
13
 
14
  # Load the dataset
15
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
 
22
  image_dir = 'civitai_images'
23
  os.makedirs(image_dir, exist_ok=True)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Load the ResNet50 model pretrained on ImageNet
26
+ model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
 
27
 
28
  # Function to extract features
29
  def extract_features(img_path, model):
 
31
  img_array = image.img_to_array(img)
32
  img_array = np.expand_dims(img_array, axis=0)
33
  img_array = preprocess_input(img_array)
34
+ features = model.predict(img_array)
 
35
  return features.flatten()
36
 
37
  # Extract features for a sample of images
 
97
 
98
  return recommended_images, recommended_model_names, recommended_distances
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def display_images(image_paths, model_names, distances):
101
  plt.figure(figsize=(20, 10))
102
  for i, (img_path, model_name, distance) in enumerate(zip(image_paths, model_names, distances)):
 
107
  plt.axis('off')
108
  plt.show()
109
 
110
+ def recommend_images(img):
111
+ recommended_images, recommended_model_names, recommended_distances = get_recommendations(img, model, nbrs, image_paths, model_names)
112
+ return [Image.open(img_path) for img_path in recommended_images], [model_name for model_name in recommended_model_names], [distance for distance in recommended_distances]
113
+
114
+ iface = gr.Interface(
115
+ fn=recommend_images,
116
+ inputs=gr.Image(label="Upload an image"),
117
+ outputs=[
118
+ gr.Gallery(label="Recommended Images", show_label=False),
119
+ gr.Textbox(label="Model Names", lines=5),
120
+ gr.Textbox(label="Distances", lines=5)
121
+ ],
122
  title="Image Recommendation System",
123
+ description="Upload an image and get recommendations based on similarity."
124
  )
125
+ iface.launch()