|
import os |
|
import requests |
|
from tqdm import tqdm |
|
from datasets import load_dataset |
|
import numpy as np |
|
import tensorflow as tf |
|
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input |
|
from tensorflow.keras.preprocessing import image |
|
from sklearn.neighbors import NearestNeighbors |
|
import joblib |
|
from PIL import UnidentifiedImageError, Image |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Fail fast if TensorFlow cannot see a GPU: later code pins work to '/GPU:0'.
# Query the device list once, and raise explicitly rather than using
# `assert`, which is stripped when Python runs with -O.
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))
if not gpus:
    raise RuntimeError("No GPU available!")
|
|
|
|
|
# Pull the CivitAI Stable Diffusion metadata dataset from the Hugging Face Hub.
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")

# Work on a reproducible random sample of the training split (fixed seed).
subset_size = 10000
shuffled_train = dataset['train'].shuffle(seed=42)
dataset_subset = shuffled_train.select(range(subset_size))

# Local cache directory for downloaded images; an existing one is reused.
image_dir = 'civitai_images'
os.makedirs(image_dir, exist_ok=True)
|
|
|
|
|
# ImageNet-pretrained ResNet50 used as a fixed feature extractor: the
# classification head is dropped (include_top=False) and global average
# pooling reduces the last feature map to one vector per input image.
with tf.device('/GPU:0'):
    model = ResNet50(
        include_top=False,
        pooling='avg',
        weights='imagenet',
    )
|
|
|
|
|
def extract_features(img_path, model):
    """Return a flat embedding of the image stored at ``img_path``.

    The image is loaded and resized to 224x224, given a leading batch axis of
    size 1, normalised with ResNet50's ``preprocess_input``, and run through
    ``model``.

    Args:
        img_path: Path to an image file readable by PIL.
        model: Keras model used as the feature extractor.

    Returns:
        1-D NumPy array holding the model output for this single image.

    Raises:
        PIL.UnidentifiedImageError: if PIL cannot decode the file.
    """
    img = image.load_img(img_path, target_size=(224, 224))
    img_array = np.expand_dims(image.img_to_array(img), axis=0)
    img_array = preprocess_input(img_array)

    with tf.device('/GPU:0'):
        # Call the model directly rather than model.predict(). For a single
        # batch this skips predict()'s per-call setup cost and the progress
        # bar it would print on every call. training=False keeps inference
        # behaviour, such as frozen BatchNorm statistics.
        features = model(img_array, training=False)

    return features.numpy().flatten()
|
|
|
|
|
def _download_image(img_url, img_path, timeout=30):
    """Download ``img_url`` and write the response body to ``img_path``.

    Raises:
        requests.exceptions.RequestException: on network or HTTP errors,
            including hitting ``timeout``.
        ValueError: if the response's Content-Type is not an image type.
        OSError: if the file cannot be written.
    """
    # Without a timeout, requests can block forever on a stalled server.
    response = requests.get(img_url, timeout=timeout)
    response.raise_for_status()

    # A missing Content-Type header is treated as "not an image" instead of
    # raising KeyError.
    if 'image' not in response.headers.get('Content-Type', ''):
        raise ValueError("URL does not contain an image")

    with open(img_path, 'wb') as f:
        f.write(response.content)


# Parallel lists: entry i of each describes the same successfully embedded image.
features = []
image_paths = []
model_names = []

for idx, sample in enumerate(tqdm(dataset_subset)):
    img_url = sample['url']
    model_name = sample['Model']
    # Prefix the sample index so distinct URLs sharing a basename (or ending
    # in '/') cannot overwrite each other on disk.
    img_path = os.path.join(image_dir, f"{idx}_{os.path.basename(img_url)}")

    try:
        _download_image(img_url, img_path)
    except requests.exceptions.RequestException as e:
        # Must come before OSError: RequestException subclasses IOError/OSError.
        print(f"RequestException: Failed to download {img_url} - {e}")
        continue
    except (ValueError, OSError) as e:
        # Non-image responses and local write failures skip this sample
        # instead of aborting the whole run.
        print(f"Skipping {img_url} - {e}")
        continue

    try:
        img_features = extract_features(img_path, model)
    except UnidentifiedImageError:
        print(f"UnidentifiedImageError: Skipping file {img_path}")
        os.remove(img_path)
        continue
    except OSError as e:
        # e.g. truncated or corrupt image data that PIL cannot finish decoding.
        print(f"OSError: Skipping file {img_path} - {e}")
        os.remove(img_path)
        continue

    features.append(img_features)
    image_paths.append(img_path)
    model_names.append(model_name)

features = np.array(features)
|
|
|
|
|
# Exact k-NN index over the image embeddings. 'brute' replaces 'ball_tree':
# tree-based search degrades toward brute-force cost on high-dimensional
# vectors like pooled CNN features, while brute returns the same exact
# neighbours without the tree-building overhead.
if len(features) == 0:
    raise RuntimeError(
        "No image features were extracted; cannot build the nearest-neighbour index."
    )
nbrs = NearestNeighbors(
    n_neighbors=min(5, len(features)),  # never request more neighbours than samples
    algorithm='brute',
).fit(features)

# Persist the index plus the parallel metadata so they can be reloaded.
joblib.dump(nbrs, 'nearest_neighbors_model.pkl')
np.save('image_features.npy', features)
np.save('image_paths.npy', image_paths)
np.save('model_names.npy', model_names)
|
|
|
|
|
# Reload the persisted index and metadata; everything below uses these
# on-disk copies (the Python lists come back as NumPy arrays).
# NOTE: joblib.load and allow_pickle=True unpickle data. That is acceptable
# only because this script wrote these files; never load untrusted files here.
nbrs = joblib.load('nearest_neighbors_model.pkl')
features, image_paths, model_names = (
    np.load('image_features.npy'),
    np.load('image_paths.npy', allow_pickle=True),
    np.load('model_names.npy', allow_pickle=True),
)
|
|
|
|
|
def get_recommendations(img_path, model, nbrs, image_paths, model_names, n_neighbors=5):
    """Find the indexed images closest to the image at ``img_path``.

    Args:
        img_path: Path of the query image.
        model: Feature-extraction model, passed through to ``extract_features``.
        nbrs: Fitted ``NearestNeighbors`` index over the stored embeddings.
        image_paths: Stored image paths, aligned with the rows of ``nbrs``.
        model_names: Stored model names, aligned with the rows of ``nbrs``.
        n_neighbors: How many neighbours to return; clamped to the number of
            indexed images.

    Returns:
        Tuple ``(images, model_names, distances)`` for the neighbours found.
    """
    img_features = extract_features(img_path, model)

    # Pass n_neighbors explicitly; otherwise kneighbors() silently uses the
    # value the index was fitted with. Clamp so we never ask for more
    # neighbours than there are indexed images.
    k = min(n_neighbors, len(image_paths))
    distances, indices = nbrs.kneighbors([img_features], n_neighbors=k)

    flat_indices = indices.flatten()
    recommended_images = [image_paths[i] for i in flat_indices]
    recommended_model_names = [model_names[i] for i in flat_indices]
    recommended_distances = distances.flatten()

    return recommended_images, recommended_model_names, recommended_distances
|
|
|
def recommend(image):
    """Gradio callback: recommend indexed images similar to an uploaded one.

    Args:
        image: PIL image supplied by the Gradio Image input (``type="pil"``).

    Returns:
        List of ``(image_path, model_name, distance)`` tuples.

    Raises:
        ValueError: if no image was uploaded.
    """
    import tempfile

    if image is None:
        raise ValueError("Please upload an image.")

    # A private temp directory per request: a fixed shared filename would let
    # concurrent requests overwrite each other's uploads, and the file is
    # removed automatically afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = os.path.join(tmp_dir, "uploaded_image.jpg")
        # JPEG cannot store alpha or palette modes (e.g. RGBA PNG uploads),
        # so convert to RGB first or save() raises OSError.
        image.convert("RGB").save(image_path)

        recommended_images, recommended_model_names, recommended_distances = get_recommendations(
            image_path, model, nbrs, image_paths, model_names
        )

    result = list(zip(recommended_images, recommended_model_names, recommended_distances))

    # NOTE(review): display_images() ends with plt.show(). Inside a server
    # callback this either blocks (interactive backend) or does nothing
    # (headless). Consider returning the images to a gr.Gallery output instead.
    display_images(recommended_images, recommended_model_names, recommended_distances)

    return result
|
|
|
def display_images(image_paths, model_names, distances):
    """Plot the given images in one row, titled with model name and distance.

    Args:
        image_paths: Paths of the images to show, one subplot each.
        model_names: Title label for each image.
        distances: Distance for each image, shown with two decimals.
    """
    n_images = len(image_paths)
    fig = plt.figure(figsize=(20, 10))

    for position, (img_path, model_name, distance) in enumerate(
        zip(image_paths, model_names, distances), start=1
    ):
        plt.subplot(1, n_images, position)
        # imshow() copies the PIL image's pixels, so the file handle can be
        # closed right away instead of leaking until garbage collection.
        with Image.open(img_path) as img:
            plt.imshow(img)
        plt.title(f'{model_name}\nDistance: {distance:.2f}', fontsize=12)
        plt.axis('off')

    plt.show()
    # This runs once per request, and pyplot keeps every figure alive until
    # it is explicitly closed, so release it here.
    plt.close(fig)
|
|
|
|
|
# Web UI: one PIL image in, the recommendation list rendered as text out.
# gr.Image replaces the legacy gr.inputs.Image namespace, which was
# deprecated in Gradio 3.x and removed in Gradio 4.
interface = gr.Interface(
    fn=recommend,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Image Recommendation System",
    description="Upload an image and get 5 recommended similar images with model names and distances."
)

interface.launch()
|
|