canadianjosieharrison committed on
Commit 5d8f225 · verified · 1 Parent(s): e5685f4

Update similarity_inference.py

Files changed (1)
  1. similarity_inference.py +82 -53
similarity_inference.py CHANGED
@@ -1,54 +1,83 @@
-from image_helpers import convert_images_to_grayscale, crop_center_largest_contour, fetch_similar
-import datasets as ds
-import re
-import torchvision.transforms as T
-from transformers import AutoModel, AutoFeatureExtractor
-import torch
-import random
-
-def similarity_inference(directory):
-    convert_images_to_grayscale(directory)
-    crop_center_largest_contour(directory)
-
-    # define processing variables needed for embedding calculation
-    root_directory = "data/" #"C:/Users/josie/OneDrive - Chalmers/Documents/Speckle hackathon/data/"
-    model_ckpt = "nateraw/vit-base-beans" ## FIND DIFFERENT MODEL
-    candidate_subset_emb = ds.load_dataset("canadianjosieharrison/2024hackathonembeddingdb")['train']
-    extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
-    model = AutoModel.from_pretrained(model_ckpt)
-    transformation_chain = T.Compose(
-        [
-            # We first resize the input image to 256x256 and then we take center crop.
-            T.Resize(int((256 / 224) * extractor.size["height"])),
-            T.CenterCrop(extractor.size["height"]),
-            T.ToTensor(),
-            T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
-        ])
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    pt_directory = root_directory + "embedding_db.pt" #"materials/embedding_db.pt"
-    all_candidate_embeddings = torch.load(pt_directory, map_location=device, weights_only=True)
-    candidate_ids = []
-    for id in range(len(candidate_subset_emb)):
-        # Create a unique indentifier.
-        entry = str(id) + "_" + str(random.random()).split('.')[1]
-        candidate_ids.append(entry)
-
-    # load all components
-    test_ds = ds.load_dataset("imagefolder", data_dir=directory)
-    label_filenames = ds.load_dataset("imagefolder", data_dir=directory).cast_column("image", ds.Image(decode=False))
-
-    # loop through each component and return top 3 most similar
-    match_dict = {"ceiling": [], "floor": [], "wall": []}
-    for i, each_component in enumerate(test_ds['train']):
-        query_image = each_component["image"]
-        component_label = label_filenames['train'][i]['image']['path'].split('_')[-1]
-        print(component_label)
-        match = re.search(r"([a-zA-Z]+)\d*\.png", component_label)
-        component_label = match.group(1)
-        sim_ids = fetch_similar(query_image, transformation_chain, device, model, all_candidate_embeddings, candidate_ids)
-        for each_match in sim_ids:
-            texture_filename = candidate_subset_emb[each_match]['filenames']
-            image_url = f'https://cdn.polyhaven.com/asset_img/thumbs/{texture_filename}?width=256&height=256'
-            match_dict[component_label].append(image_url)
-
+from image_helpers import convert_images_to_grayscale, crop_center_largest_contour, fetch_similar
+import datasets as ds
+import re
+import torchvision.transforms as T
+from transformers import AutoModel, AutoFeatureExtractor
+import torch
+import random
+import os
+from PIL import Image
+import numpy as np
+
+def similarity_inference(directory):
+
+    # Get an average color value for each component image
+    color_dict = {}
+    for each_image in os.listdir(directory):
+        image_path = os.path.join(directory, each_image)
+        with Image.open(image_path) as img:
+            width, height = img.size
+            # sample 100 random pixels and keep the non-white ones
+            colors = []
+            for i in range(100):
+                # choose a random pixel
+                random_x = random.randint(0, width - 1)
+                random_y = random.randint(0, height - 1)
+                random_pixel = img.getpixel((random_x, random_y))
+                # if the pixel is not white
+                if random_pixel != (255, 255, 255):
+                    colors.append(random_pixel)
+            colors_array = np.array(colors)
+            average_color_value = tuple(np.mean(colors_array, axis=0).astype(int))
+            color_dict[each_image] = average_color_value
+
+    convert_images_to_grayscale(directory)
+    crop_center_largest_contour(directory)
+
+    # define processing variables needed for embedding calculation
+    root_directory = "data/" #"C:/Users/josie/OneDrive - Chalmers/Documents/Speckle hackathon/data/"
+    model_ckpt = "nateraw/vit-base-beans" ## FIND DIFFERENT MODEL
+    candidate_subset_emb = ds.load_dataset("canadianjosieharrison/2024hackathonembeddingdb")['train']
+    extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
+    model = AutoModel.from_pretrained(model_ckpt)
+    transformation_chain = T.Compose(
+        [
+            # We first resize the input image to 256x256 and then take a center crop.
+            T.Resize(int((256 / 224) * extractor.size["height"])),
+            T.CenterCrop(extractor.size["height"]),
+            T.ToTensor(),
+            T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
+        ])
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    pt_directory = root_directory + "embedding_db.pt" #"materials/embedding_db.pt"
+    all_candidate_embeddings = torch.load(pt_directory, map_location=device, weights_only=True)
+    candidate_ids = []
+    for id in range(len(candidate_subset_emb)):
+        # Create a unique identifier.
+        entry = str(id) + "_" + str(random.random()).split('.')[1]
+        candidate_ids.append(entry)
+
+    # load all components
+    test_ds = ds.load_dataset("imagefolder", data_dir=directory)
+    label_filenames = ds.load_dataset("imagefolder", data_dir=directory).cast_column("image", ds.Image(decode=False))
+
+    # loop through each component and return the top 3 most similar
+    match_dict = {"ceiling": [],
+                  "floor": [],
+                  "wall": []}
+    for i, each_component in enumerate(test_ds['train']):
+        query_image = each_component["image"]
+        component_label = label_filenames['train'][i]['image']['path'].split('_')[-1].split("\\")[-1]
+        rgb_color = color_dict[component_label]
+        match = re.search(r"([a-zA-Z]+)(\d*)\.png", component_label)
+        component_label = match.group(1)
+        segment_id = match.group(2)
+        sim_ids = fetch_similar(query_image, transformation_chain, device, model, all_candidate_embeddings, candidate_ids)
+        for each_match in sim_ids:
+            component_texture_id = str(segment_id) + "-" + str(each_match)
+            texture_filename = candidate_subset_emb[each_match]['filenames']
+            image_url = f'https://cdn.polyhaven.com/asset_img/thumbs/{texture_filename}?width=256&height=256'
+            temp_dict = {"id": component_texture_id, "thumbnail": image_url, "name": texture_filename, "color": str(rgb_color)}
+            match_dict[component_label].append(temp_dict)
+
     return match_dict
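
For reference, a minimal sketch of how the updated function might be called. The "components/" directory and its filenames are placeholders, not part of the commit:

    # Hypothetical caller for the function updated in this commit.
    from similarity_inference import similarity_inference

    matches = similarity_inference("components/")

    # With this change each match is a dict rather than a bare thumbnail URL, e.g.:
    # {"wall": [{"id": "1-42",
    #            "thumbnail": "https://cdn.polyhaven.com/asset_img/thumbs/...?width=256&height=256",
    #            "name": "...",
    #            "color": "(128, 117, 103)"},
    #           ...],
    #  "floor": [...], "ceiling": [...]}
    for component_type, textures in matches.items():
        print(component_type, len(textures))

One edge case in the new sampling loop: if all 100 sampled pixels are white, colors stays empty and the np.mean call degenerates to NaN, so a fallback color for near-blank component images may be worth adding.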