bgaspra commited on
Commit
2d4593d
·
verified ·
1 Parent(s): b242893

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -17
app.py CHANGED
@@ -3,6 +3,7 @@ import requests
3
  from tqdm import tqdm
4
  from datasets import load_dataset
5
  import numpy as np
 
6
  from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
7
  from tensorflow.keras.preprocessing import image
8
  from sklearn.neighbors import NearestNeighbors
@@ -22,8 +23,23 @@ dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
22
  image_dir = 'civitai_images'
23
  os.makedirs(image_dir, exist_ok=True)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Load the ResNet50 model pretrained on ImageNet
26
- model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
 
27
 
28
  # Function to extract features
29
  def extract_features(img_path, model):
@@ -31,7 +47,8 @@ def extract_features(img_path, model):
31
  img_array = image.img_to_array(img)
32
  img_array = np.expand_dims(img_array, axis=0)
33
  img_array = preprocess_input(img_array)
34
- features = model.predict(img_array)
 
35
  return features.flatten()
36
 
37
  # Extract features for a sample of images
@@ -105,26 +122,27 @@ def recommend(image):
105
  recommended_images, recommended_model_names, recommended_distances = get_recommendations(image_path, model, nbrs, image_paths, model_names)
106
  result = list(zip(recommended_images, recommended_model_names, recommended_distances))
107
 
108
- # Prepare HTML output for Gradio
109
- html_output = ""
110
- for img_path, model_name, distance in zip(recommended_images, recommended_model_names, recommended_distances):
111
- img_path = img_path.replace('\\', '/')
112
- html_output += f"""
113
- <div style='display:inline-block; text-align:center; margin:10px;'>
114
- <img src='file/{img_path}' style='width:200px; height:200px;'><br>
115
- <b>Model Name:</b> {model_name}<br>
116
- <b>Distance:</b> {distance:.2f}<br>
117
- </div>
118
- """
119
 
120
- return html_output
 
 
 
 
 
 
 
 
 
 
121
 
122
  interface = gr.Interface(
123
  fn=recommend,
124
- inputs=gr.Image(type="pil"),
125
- outputs=gr.HTML(), # Use HTML output for better formatting
126
  title="Image Recommendation System",
127
  description="Upload an image and get 5 recommended similar images with model names and distances."
128
  )
129
 
130
- interface.launch()
 
3
  from tqdm import tqdm
4
  from datasets import load_dataset
5
  import numpy as np
6
+ import tensorflow as tf
7
  from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
8
  from tensorflow.keras.preprocessing import image
9
  from sklearn.neighbors import NearestNeighbors
 
23
  image_dir = 'civitai_images'
24
  os.makedirs(image_dir, exist_ok=True)
25
 
26
+ # Try to use GPU, fall back to CPU if not available
27
+ try:
28
+ gpus = tf.config.list_physical_devices('GPU')
29
+ if gpus:
30
+ tf.config.experimental.set_memory_growth(gpus[0], True)
31
+ device = '/GPU:0'
32
+ print("Using GPU")
33
+ else:
34
+ raise RuntimeError("No GPU found")
35
+ except RuntimeError as e:
36
+ print(e)
37
+ device = '/CPU:0'
38
+ print("Using CPU")
39
+
40
  # Load the ResNet50 model pretrained on ImageNet
41
+ with tf.device(device):
42
+ model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
43
 
44
  # Function to extract features
45
  def extract_features(img_path, model):
 
47
  img_array = image.img_to_array(img)
48
  img_array = np.expand_dims(img_array, axis=0)
49
  img_array = preprocess_input(img_array)
50
+ with tf.device(device):
51
+ features = model.predict(img_array)
52
  return features.flatten()
53
 
54
  # Extract features for a sample of images
 
122
  recommended_images, recommended_model_names, recommended_distances = get_recommendations(image_path, model, nbrs, image_paths, model_names)
123
  result = list(zip(recommended_images, recommended_model_names, recommended_distances))
124
 
125
+ # Display images with matplotlib
126
+ display_images(recommended_images, recommended_model_names, recommended_distances)
 
 
 
 
 
 
 
 
 
127
 
128
+ return result
129
+
130
+ def display_images(image_paths, model_names, distances):
131
+ plt.figure(figsize=(20, 10))
132
+ for i, (img_path, model_name, distance) in enumerate(zip(image_paths, model_names, distances)):
133
+ img = Image.open(img_path)
134
+ plt.subplot(1, len(image_paths), i+1)
135
+ plt.imshow(img)
136
+ plt.title(f'{model_name}\nDistance: {distance:.2f}', fontsize=12)
137
+ plt.axis('off')
138
+ plt.show()
139
 
140
  interface = gr.Interface(
141
  fn=recommend,
142
+ inputs=gr.Image(type="pil"), # Updated input component
143
+ outputs=gr.Text(), # Updated output component
144
  title="Image Recommendation System",
145
  description="Upload an image and get 5 recommended similar images with model names and distances."
146
  )
147
 
148
+ interface.launch()