Spaces:
Runtime error
Runtime error
Update similarity_inference.py
Browse files- similarity_inference.py +82 -53
similarity_inference.py
CHANGED
@@ -1,54 +1,83 @@
|
|
1 |
-
from image_helpers import convert_images_to_grayscale, crop_center_largest_contour, fetch_similar
|
2 |
-
import datasets as ds
|
3 |
-
import re
|
4 |
-
import torchvision.transforms as T
|
5 |
-
from transformers import AutoModel, AutoFeatureExtractor
|
6 |
-
import torch
|
7 |
-
import random
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
return match_dict
|
|
|
1 |
+
from image_helpers import convert_images_to_grayscale, crop_center_largest_contour, fetch_similar
|
2 |
+
import datasets as ds
|
3 |
+
import re
|
4 |
+
import torchvision.transforms as T
|
5 |
+
from transformers import AutoModel, AutoFeatureExtractor
|
6 |
+
import torch
|
7 |
+
import random
|
8 |
+
import os
|
9 |
+
from PIL import Image
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
def _average_non_white_color(image_path, samples=100):
    """Return the mean RGB colour of randomly sampled non-white pixels.

    Args:
        image_path: Path of the image file to sample.
        samples: Number of random pixels to draw (white draws are discarded).

    Returns:
        A 3-tuple of plain Python ints ``(r, g, b)``. If every sampled pixel
        is pure white, ``(255, 255, 255)`` is returned.
    """
    with Image.open(image_path) as img:
        # Normalise the mode so getpixel() always yields an (r, g, b) tuple;
        # grayscale/palette images return ints and RGBA returns 4-tuples,
        # which would break the white test and the per-channel mean below.
        rgb_img = img.convert("RGB")
        width, height = rgb_img.size
        colors = []
        for _ in range(samples):
            random_x = random.randint(0, width - 1)
            random_y = random.randint(0, height - 1)
            pixel = rgb_img.getpixel((random_x, random_y))
            if pixel != (255, 255, 255):
                colors.append(pixel)

    if not colors:
        # Only white pixels were drawn: averaging an empty array would fail.
        return (255, 255, 255)

    mean_rgb = np.mean(np.array(colors), axis=0).astype(int)
    # Plain ints keep str(color) as "(r, g, b)" regardless of NumPy version
    # (NumPy 2 renders scalars as "np.int64(...)").
    return tuple(int(channel) for channel in mean_rgb)


def similarity_inference(directory):
    """Match each component image in *directory* to similar textures.

    The function samples an average colour for every image file, runs the
    project's grayscale and contour-crop helpers on the directory, then asks
    ``fetch_similar`` for candidate textures for each image of the
    ``imagefolder`` dataset built from the directory.

    Component images are expected to be named ``<label><digits>.png``
    (for example ``wall2.png``); files that do not follow this pattern are
    skipped.

    Args:
        directory: Path to the folder containing the component images.
            NOTE(review): the helpers appear to modify the files in place --
            confirm against ``image_helpers``.

    Returns:
        A dict mapping the component label (``"ceiling"``, ``"floor"``,
        ``"wall"``, plus any other label found in the filenames) to a list of
        dicts with the keys ``"id"``, ``"thumbnail"``, ``"name"`` and
        ``"color"``.

    Raises:
        KeyError: If no colour was sampled for a dataset image's filename.
    """
    # Get colour values for each component before the images are converted.
    color_dict = {}
    for each_image in os.listdir(directory):
        image_path = os.path.join(directory, each_image)
        if not os.path.isfile(image_path):
            # Sub-directories cannot be opened as images.
            continue
        color_dict[each_image] = _average_non_white_color(image_path)

    convert_images_to_grayscale(directory)
    crop_center_largest_contour(directory)

    # Define processing variables needed for embedding calculation.
    root_directory = "data/"
    model_ckpt = "nateraw/vit-base-beans"  # TODO: find a different model
    candidate_subset_emb = ds.load_dataset(
        "canadianjosieharrison/2024hackathonembeddingdb"
    )["train"]
    extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
    model = AutoModel.from_pretrained(model_ckpt)
    transformation_chain = T.Compose(
        [
            # Resize slightly larger than the model input, then centre crop.
            T.Resize(int((256 / 224) * extractor.size["height"])),
            T.CenterCrop(extractor.size["height"]),
            T.ToTensor(),
            T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
        ]
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pt_directory = os.path.join(root_directory, "embedding_db.pt")
    all_candidate_embeddings = torch.load(
        pt_directory, map_location=device, weights_only=True
    )

    # Create a unique identifier "<index>_<random digits>" per candidate.
    candidate_ids = [
        str(index) + "_" + str(random.random()).split(".")[-1]
        for index in range(len(candidate_subset_emb))
    ]

    # Load all components once; the undecoded view exposes the file paths.
    test_ds = ds.load_dataset("imagefolder", data_dir=directory)
    label_filenames = test_ds.cast_column("image", ds.Image(decode=False))

    # Loop through each component and collect its most similar textures.
    match_dict = {"ceiling": [], "floor": [], "wall": []}
    filename_pattern = re.compile(r"([a-zA-Z]+)(\d*)\.png")
    for each_component, each_file in zip(test_ds["train"], label_filenames["train"]):
        query_image = each_component["image"]
        # Accept both Windows and POSIX separators, then keep the part after
        # the last underscore exactly as the original lookup key did.
        base_name = re.split(r"[\\/]", each_file["image"]["path"])[-1]
        file_key = base_name.split("_")[-1]

        match = filename_pattern.search(file_key)
        if match is None:
            # Not a "<label><digits>.png" component image; nothing to label.
            continue
        component_label = match.group(1)
        segment_id = match.group(2)
        rgb_color = color_dict[file_key]

        sim_ids = fetch_similar(
            query_image,
            transformation_chain,
            device,
            model,
            all_candidate_embeddings,
            candidate_ids,
        )
        for each_match in sim_ids:
            component_texture_id = str(segment_id) + "-" + str(each_match)
            texture_filename = candidate_subset_emb[each_match]['filenames']
            image_url = f'https://cdn.polyhaven.com/asset_img/thumbs/{texture_filename}?width=256&height=256'
            temp_dict = {
                "id": component_texture_id,
                "thumbnail": image_url,
                "name": texture_filename,
                "color": str(rgb_color),
            }
            # setdefault keeps results for labels beyond the three defaults.
            match_dict.setdefault(component_label, []).append(temp_dict)

    return match_dict
|