bgaspra commited on
Commit
b242893
·
verified ·
1 Parent(s): e6d2088

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -33
app.py CHANGED
@@ -3,7 +3,6 @@ import requests
3
  from tqdm import tqdm
4
  from datasets import load_dataset
5
  import numpy as np
6
- import tensorflow as tf
7
  from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
8
  from tensorflow.keras.preprocessing import image
9
  from sklearn.neighbors import NearestNeighbors
@@ -23,23 +22,8 @@ dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
23
  image_dir = 'civitai_images'
24
  os.makedirs(image_dir, exist_ok=True)
25
 
26
- # Try to use GPU, fall back to CPU if not available
27
- try:
28
- gpus = tf.config.list_physical_devices('GPU')
29
- if gpus:
30
- tf.config.experimental.set_memory_growth(gpus[0], True)
31
- device = '/GPU:0'
32
- print("Using GPU")
33
- else:
34
- raise RuntimeError("No GPU found")
35
- except RuntimeError as e:
36
- print(e)
37
- device = '/CPU:0'
38
- print("Using CPU")
39
-
40
  # Load the ResNet50 model pretrained on ImageNet
41
- with tf.device(device):
42
- model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
43
 
44
  # Function to extract features
45
  def extract_features(img_path, model):
@@ -47,8 +31,7 @@ def extract_features(img_path, model):
47
  img_array = image.img_to_array(img)
48
  img_array = np.expand_dims(img_array, axis=0)
49
  img_array = preprocess_input(img_array)
50
- with tf.device(device):
51
- features = model.predict(img_array)
52
  return features.flatten()
53
 
54
  # Extract features for a sample of images
@@ -122,15 +105,13 @@ def recommend(image):
122
  recommended_images, recommended_model_names, recommended_distances = get_recommendations(image_path, model, nbrs, image_paths, model_names)
123
  result = list(zip(recommended_images, recommended_model_names, recommended_distances))
124
 
125
- # Display images with matplotlib
126
- display_images(recommended_images, recommended_model_names, recommended_distances)
127
-
128
  # Prepare HTML output for Gradio
129
  html_output = ""
130
  for img_path, model_name, distance in zip(recommended_images, recommended_model_names, recommended_distances):
 
131
  html_output += f"""
132
  <div style='display:inline-block; text-align:center; margin:10px;'>
133
- <img src='file://{img_path}' style='width:200px; height:200px;'><br>
134
  <b>Model Name:</b> {model_name}<br>
135
  <b>Distance:</b> {distance:.2f}<br>
136
  </div>
@@ -138,16 +119,6 @@ def recommend(image):
138
 
139
  return html_output
140
 
141
- def display_images(image_paths, model_names, distances):
142
- plt.figure(figsize=(20, 10))
143
- for i, (img_path, model_name, distance) in enumerate(zip(image_paths, model_names, distances)):
144
- img = Image.open(img_path)
145
- plt.subplot(1, len(image_paths), i+1)
146
- plt.imshow(img)
147
- plt.title(f'{model_name}\nDistance: {distance:.2f}', fontsize=12)
148
- plt.axis('off')
149
- plt.show()
150
-
151
  interface = gr.Interface(
152
  fn=recommend,
153
  inputs=gr.Image(type="pil"),
 
3
  from tqdm import tqdm
4
  from datasets import load_dataset
5
  import numpy as np
 
6
  from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
7
  from tensorflow.keras.preprocessing import image
8
  from sklearn.neighbors import NearestNeighbors
 
22
  image_dir = 'civitai_images'
23
  os.makedirs(image_dir, exist_ok=True)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Load the ResNet50 model pretrained on ImageNet
26
+ model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
 
27
 
28
  # Function to extract features
29
  def extract_features(img_path, model):
 
31
  img_array = image.img_to_array(img)
32
  img_array = np.expand_dims(img_array, axis=0)
33
  img_array = preprocess_input(img_array)
34
+ features = model.predict(img_array)
 
35
  return features.flatten()
36
 
37
  # Extract features for a sample of images
 
105
  recommended_images, recommended_model_names, recommended_distances = get_recommendations(image_path, model, nbrs, image_paths, model_names)
106
  result = list(zip(recommended_images, recommended_model_names, recommended_distances))
107
 
 
 
 
108
  # Prepare HTML output for Gradio
109
  html_output = ""
110
  for img_path, model_name, distance in zip(recommended_images, recommended_model_names, recommended_distances):
111
+ img_path = img_path.replace('\\', '/')
112
  html_output += f"""
113
  <div style='display:inline-block; text-align:center; margin:10px;'>
114
+ <img src='file/{img_path}' style='width:200px; height:200px;'><br>
115
  <b>Model Name:</b> {model_name}<br>
116
  <b>Distance:</b> {distance:.2f}<br>
117
  </div>
 
119
 
120
  return html_output
121
 
 
 
 
 
 
 
 
 
 
 
122
  interface = gr.Interface(
123
  fn=recommend,
124
  inputs=gr.Image(type="pil"),