bgaspra committed on
Commit
6a4d9dc
·
verified ·
1 Parent(s): ea5686d

First Commit

Browse files
Files changed (1) hide show
  1. app.py +139 -0
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from tqdm import tqdm
4
+ from datasets import load_dataset
5
+ import numpy as np
6
+ import tensorflow as tf
7
+ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
8
+ from tensorflow.keras.preprocessing import image
9
+ from sklearn.neighbors import NearestNeighbors
10
+ import joblib
11
+ from PIL import UnidentifiedImageError, Image
12
+ import gradio as gr
13
+ import matplotlib.pyplot as plt
14
+
15
# Ensure TensorFlow uses GPU: fail fast when none is visible, because the
# model below is pinned to '/GPU:0'.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))
if not gpu_devices:
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    raise RuntimeError("No GPU available!")

# Load the dataset
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")

# Take a reproducible random subset (fixed shuffle seed), never requesting
# more rows than the split actually holds.
subset_size = 10000
train_split = dataset['train']
dataset_subset = train_split.shuffle(seed=42).select(
    range(min(subset_size, len(train_split)))
)

# Directory to save images
image_dir = 'civitai_images'
os.makedirs(image_dir, exist_ok=True)

# Load the ResNet50 model pretrained on ImageNet, without the classification
# head and with global average pooling applied to the final feature map.
with tf.device('/GPU:0'):
    model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
33
+
34
+ # Function to extract features
35
def extract_features(img_path, model):
    """Return the flattened prediction of ``model`` for one image file.

    The image at ``img_path`` is loaded at 224x224, converted to an array,
    given a leading axis of length 1, passed through ``preprocess_input``
    and then predicted inside a ``/GPU:0`` device scope.
    """
    pil_img = image.load_img(img_path, target_size=(224, 224))
    # Add a leading axis so the array forms a single-image batch.
    batch = np.expand_dims(image.img_to_array(pil_img), axis=0)
    batch = preprocess_input(batch)
    with tf.device('/GPU:0'):
        prediction = model.predict(batch)
    return prediction.flatten()
43
+
44
# Extract features for a sample of images. The three lists stay aligned:
# entry i of each one describes the same successfully processed image.
features = []
image_paths = []
model_names = []

for sample in tqdm(dataset_subset):
    img_url = sample['url']  # Adjust based on the correct column name
    model_name = sample['Model']  # Adjust based on the correct column name
    img_path = os.path.join(image_dir, os.path.basename(img_url))

    # Download the image
    try:
        # A timeout keeps one stalled server from hanging the whole run.
        response = requests.get(img_url, timeout=30)
        response.raise_for_status()  # Check if the download was successful
    except requests.exceptions.RequestException as e:
        print(f"RequestException: Failed to download {img_url} - {e}")
        continue

    # .get() tolerates a missing Content-Type header. Non-image payloads are
    # skipped instead of raised, so a single bad URL cannot abort the loop.
    if 'image' not in response.headers.get('Content-Type', ''):
        print(f"Skipping {img_url}: URL does not contain an image")
        continue

    with open(img_path, 'wb') as f:
        f.write(response.content)

    # Extract features
    try:
        img_features = extract_features(img_path, model)
    except UnidentifiedImageError:
        print(f"UnidentifiedImageError: Skipping file {img_path}")
        # Do not keep files that cannot be decoded as images.
        os.remove(img_path)
        continue

    features.append(img_features)
    image_paths.append(img_path)
    model_names.append(model_name)
77
+
78
# Convert features to numpy array
features = np.array(features)
if features.size == 0:
    # Fail with a clear message rather than an opaque shape error from fit().
    raise RuntimeError("No image features were extracted; cannot build the index.")

# Build the NearestNeighbors model. The neighbour count is capped at the
# number of indexed images so that queries still work on a small index.
nbrs = NearestNeighbors(
    n_neighbors=min(5, len(features)),
    algorithm='ball_tree',
).fit(features)

# Save the model and features
joblib.dump(nbrs, 'nearest_neighbors_model.pkl')
np.save('image_features.npy', features)
np.save('image_paths.npy', image_paths)
np.save('model_names.npy', model_names)

# Load the NearestNeighbors model and features
# NOTE(review): joblib.load and allow_pickle=True both unpickle data, which
# can execute arbitrary code. That is acceptable only because these files
# were written by this script just above; never point them at untrusted files.
nbrs = joblib.load('nearest_neighbors_model.pkl')
features = np.load('image_features.npy')
image_paths = np.load('image_paths.npy', allow_pickle=True)
model_names = np.load('model_names.npy', allow_pickle=True)
95
+
96
+ # Function to get recommendations
97
def get_recommendations(img_path, model, nbrs, image_paths, model_names, n_neighbors=5):
    """Find the indexed images closest to the image at ``img_path``.

    Args:
        img_path: Path of the query image.
        model: Feature extractor handed to ``extract_features``.
        nbrs: Fitted nearest-neighbours index exposing ``kneighbors``.
        image_paths: Sequence indexed by the positions ``kneighbors`` returns.
        model_names: Sequence indexed the same way as ``image_paths``.
        n_neighbors: Maximum number of neighbours to return.

    Returns:
        A tuple ``(recommended_images, recommended_model_names,
        recommended_distances)``: two lists and a flat array of distances,
        all in the order returned by ``kneighbors``.
    """
    img_features = extract_features(img_path, model)

    # Pass the requested neighbour count through to the query, capped at the
    # index size because kneighbors rejects more neighbours than samples.
    k = min(n_neighbors, len(image_paths))
    distances, indices = nbrs.kneighbors([img_features], n_neighbors=k)

    # One query row was submitted, so flattening yields that row's neighbours.
    neighbor_ids = indices.flatten()
    recommended_images = [image_paths[idx] for idx in neighbor_ids]
    recommended_model_names = [model_names[idx] for idx in neighbor_ids]
    recommended_distances = distances.flatten()

    return recommended_images, recommended_model_names, recommended_distances
106
+
107
def recommend(image):
    """Return the nearest indexed images for an uploaded PIL image.

    Args:
        image: The uploaded picture as a PIL image, or ``None`` when nothing
            was uploaded.

    Returns:
        A list of ``(image_path, model_name, distance)`` tuples using plain
        ``str``/``float`` values; empty when ``image`` is ``None``.
    """
    import tempfile  # Local import: only this function needs it.

    if image is None:
        # Nothing was uploaded, so there is nothing to compare against.
        return []

    # Write the upload to a unique temporary file so that concurrent requests
    # cannot overwrite each other's input.
    fd, image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        # JPEG cannot store alpha or palette modes; normalise to RGB first.
        image.convert("RGB").save(image_path)
        recommended_images, recommended_model_names, recommended_distances = get_recommendations(
            image_path, model, nbrs, image_paths, model_names
        )
    finally:
        os.remove(image_path)

    # Convert NumPy scalars to builtin types so the text output stays readable.
    result = [
        (str(path), str(name), float(distance))
        for path, name, distance in zip(
            recommended_images, recommended_model_names, recommended_distances
        )
    ]

    # Display images with matplotlib
    display_images(recommended_images, recommended_model_names, recommended_distances)

    return result
119
+
120
def display_images(image_paths, model_names, distances):
    """Show the given images in one row, titled with model name and distance.

    Args:
        image_paths: Paths of the images to draw, one subplot each.
        model_names: Title text for each image.
        distances: Numeric distance for each image, shown to two decimals.
    """
    if len(image_paths) == 0:
        # plt.subplot rejects a zero-column grid, so there is nothing to draw.
        return

    fig = plt.figure(figsize=(20, 10))
    for i, (img_path, model_name, distance) in enumerate(zip(image_paths, model_names, distances)):
        plt.subplot(1, len(image_paths), i + 1)
        # Close each file handle as soon as the image has been drawn.
        with Image.open(img_path) as img:
            plt.imshow(img)
        plt.title(f'{model_name}\nDistance: {distance:.2f}', fontsize=12)
        plt.axis('off')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
129
+
130
# Gradio interface
# Newer Gradio releases expose components at the top level and no longer
# ship the legacy `gr.inputs` namespace; fall back to it only when the
# top-level component is unavailable.
if hasattr(gr, "Image"):
    image_input = gr.Image(type="pil")
else:
    image_input = gr.inputs.Image(type="pil")

interface = gr.Interface(
    fn=recommend,
    inputs=image_input,
    outputs="text",  # Outputs the list of recommended images, models, and distances
    title="Image Recommendation System",
    description="Upload an image and get 5 recommended similar images with model names and distances."
)

interface.launch()