File size: 3,978 Bytes
6a4d9dc
 
 
 
 
 
 
 
 
 
61e280c
6a4d9dc
 
 
 
 
64e976b
6a4d9dc
 
 
 
 
 
 
61e280c
6a4d9dc
 
 
 
 
 
 
61e280c
6a4d9dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cb34ee
 
 
 
6a4d9dc
 
 
 
 
 
 
d024e9a
6a4d9dc
9cb34ee
 
d024e9a
dd78c52
d024e9a
6a4d9dc
9cb34ee
6a4d9dc
0fcbd10
acc7002
9cb34ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import requests
from tqdm import tqdm
from datasets import load_dataset
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
from sklearn.neighbors import NearestNeighbors
import joblib
from PIL import UnidentifiedImageError, Image
import gradio as gr

# --- Data -------------------------------------------------------------------
# Pull the Civitai dataset from the Hugging Face Hub, then keep only a random
# sample of its "train" split. The fixed seed makes the sample reproducible,
# and the small size keeps the indexing step fast enough for Spaces.
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
subset_size = 2700
train_split = dataset['train']
dataset_subset = train_split.shuffle(seed=42).select(range(subset_size))

# Downloaded images are written into this directory (created if missing).
image_dir = 'civitai_images'
os.makedirs(image_dir, exist_ok=True)

# --- Model ------------------------------------------------------------------
# ImageNet-pretrained ResNet50 without its classification head. Global average
# pooling collapses the last feature map into one embedding vector per image.
model = ResNet50(include_top=False, weights='imagenet', pooling='avg')

# Function to extract features
def extract_features(img_path, model):
    img = image.load_img(img_path, target_size=(224, 224))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)
    features = model.predict(img_array)
    return features.flatten()

# Download each sampled image, embed it with the model, and record its path and
# model name. Per-image failures are reported and skipped, so one bad URL or
# corrupt file cannot abort the whole indexing run.
DOWNLOAD_TIMEOUT = 30  # seconds; requests.get waits indefinitely without a timeout

features = []
image_paths = []
model_names = []

for sample in tqdm(dataset_subset):
    img_url = sample['url']  # Adjust based on the correct column name
    model_name = sample['Model']  # Adjust based on the correct column name
    img_path = os.path.join(image_dir, os.path.basename(img_url))

    # Download the image
    try:
        response = requests.get(img_url, timeout=DOWNLOAD_TIMEOUT)
        response.raise_for_status()  # Check if the download was successful
    except requests.exceptions.RequestException as e:
        print(f"RequestException: Failed to download {img_url} - {e}")
        continue

    # Skip non-image responses. A missing Content-Type header is treated as
    # "not an image" rather than raising KeyError. The original raised a
    # ValueError here that nothing caught, which crashed the whole loop.
    if 'image' not in response.headers.get('Content-Type', ''):
        print(f"Skipping {img_url}: URL does not contain an image")
        continue

    with open(img_path, 'wb') as f:
        f.write(response.content)

    # Extract features. UnidentifiedImageError is a subclass of OSError.
    # Catching OSError also covers files that PIL recognises but cannot fully
    # decode, e.g. truncated downloads.
    try:
        img_features = extract_features(img_path, model)
    except OSError as e:
        print(f"{type(e).__name__}: Skipping file {img_path}")
        os.remove(img_path)
        continue

    features.append(img_features)
    image_paths.append(img_path)
    model_names.append(model_name)

if not features:
    raise RuntimeError(
        "No images could be downloaded and embedded; cannot build the index."
    )

# Convert features to a 2-D numpy array (one row per indexed image).
features = np.array(features)

# kneighbors() raises if n_neighbors exceeds the number of fitted samples,
# so cap the default for small or heavily filtered subsets.
nbrs = NearestNeighbors(
    n_neighbors=min(5, len(features)), algorithm='ball_tree'
).fit(features)

# Save the model and features
joblib.dump(nbrs, 'nearest_neighbors_model.pkl')
np.save('image_features.npy', features)
np.save('image_paths.npy', image_paths)
np.save('model_names.npy', model_names)

# Reload the saved artifacts so the app below uses exactly what was persisted.
# NOTE: joblib.load unpickles; only load files this script itself produced.
nbrs = joblib.load('nearest_neighbors_model.pkl')
features = np.load('image_features.npy')
image_paths = np.load('image_paths.npy', allow_pickle=True)
model_names = np.load('model_names.npy', allow_pickle=True)

def get_recommendations(img, n_neighbors=5):
    """Return indexed images similar to ``img`` for display in a Gradio gallery.

    Args:
        img: Query image as a PIL image (the Gradio input uses ``type="pil"``),
            or ``None`` when the user submits without uploading anything.
        n_neighbors: Maximum number of neighbours to return. It is capped at
            the number of indexed images.

    Returns:
        A list of ``(PIL.Image, caption)`` pairs, in the order returned by
        ``nbrs.kneighbors``. Each caption is ``"<model name>, Distance: <d>"``.
        Returns an empty list when ``img`` is ``None``.
    """
    import tempfile

    if img is None:
        return []

    # extract_features() needs a file path. A unique temp file stops
    # concurrent requests from overwriting each other's query image, and the
    # finally block guarantees the file is removed.
    fd, query_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        img.save(query_path)
        img_features = extract_features(query_path, model)
    finally:
        os.remove(query_path)

    # Honour the caller's n_neighbors (previously ignored), without asking for
    # more neighbours than there are indexed images.
    k = min(n_neighbors, len(image_paths))
    distances, indices = nbrs.kneighbors([img_features], n_neighbors=k)

    results = []
    for idx, dist in zip(indices[0], distances[0]):
        # Image.open is lazy and keeps the file open. copy() forces a full load,
        # so the handle can be released by the context manager immediately.
        with Image.open(image_paths[idx]) as neighbour:
            picture = neighbour.copy()
        results.append((picture, f'{model_names[idx]}, Distance: {dist:.2f}'))
    return results

# --- Web UI -----------------------------------------------------------------
# One PIL image upload in; a gallery of captioned neighbour images out.
_query_input = gr.Image(type="pil")
_results_gallery = gr.Gallery(label="Recommended Images")

interface = gr.Interface(
    fn=get_recommendations,
    inputs=_query_input,
    outputs=_results_gallery,
    title="Image Recommendation System",
    description=(
        "Upload an image and get similar images with their model names and distances."
    ),
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()