"""
Generate and save the embeddings of a pre-defined list of icons.
Compare them with keywords embeddings to find most relevant icons.
"""
import os
import pathlib
import sys
from typing import List, Tuple
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BertTokenizer, BertModel
sys.path.append('..')
sys.path.append('../..')
from global_config import GlobalConfig
tokenizer = BertTokenizer.from_pretrained(GlobalConfig.TINY_BERT_MODEL)
model = BertModel.from_pretrained(GlobalConfig.TINY_BERT_MODEL)
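# GlobalConfig.TINY_BERT_MODEL is assumed to name a small Hugging Face BERT
# checkpoint (e.g., 'prajjwal1/bert-tiny'); the result notes at the bottom of
# this file compare the tiny, mini, and small variants.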


def get_icons_list() -> List[str]:
    """
    Get a list of available icons.

    :return: The icon file names (without the .png extension).
    """

    items = pathlib.Path('../' + GlobalConfig.ICONS_DIR).glob('*.png')
    items = [
        os.path.basename(str(item)).removesuffix('.png') for item in items
    ]

    return items


def get_embeddings(texts) -> np.ndarray:
    """
    Generate embeddings for a list of texts using a pre-trained language model.

    :param texts: A string or a list of strings to be converted into embeddings.
    :type texts: Union[str, List[str]]
    :return: A NumPy array containing the embeddings for the input texts.
    :rtype: numpy.ndarray
    :raises ValueError: If the input is not a string or a list of strings, or if any element
        in the list is not a string.

    Example usage:

    >>> keyword = 'neural network'
    >>> file_names = ['neural_network_icon.png', 'data_analysis_icon.png', 'machine_learning.png']
    >>> keyword_embeddings = get_embeddings(keyword)
    >>> file_name_embeddings = get_embeddings(file_names)
    """

    # Validate the input, matching the ValueError documented above
    if not isinstance(texts, str) and not (
            isinstance(texts, list) and all(isinstance(text, str) for text in texts)
    ):
        raise ValueError('The input must be a string or a list of strings')

    inputs = tokenizer(texts, return_tensors='pt', padding=True, max_length=128, truncation=True)
    outputs = model(**inputs)

    # Mean-pool the token-level hidden states into one fixed-size vector per input text
    return outputs.last_hidden_state.mean(dim=1).detach().numpy()
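
# Note: the mean pooling above averages over every token position, including
# [PAD] tokens when a batch of texts of different lengths is padded. A
# padding-aware variant (a sketch, not used in this module) would weight the
# hidden states by the attention mask instead:
#
#     mask = inputs['attention_mask'].unsqueeze(-1)
#     pooled = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)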


def save_icons_embeddings():
    """
    Generate and save the embeddings for the icon file names.
    """

    file_names = get_icons_list()
    print(f'{len(file_names)} icon files available...')
    file_name_embeddings = get_embeddings(file_names)
    print(f'file_name_embeddings.shape: {file_name_embeddings.shape}')

    # Save embeddings to a file
    np.save(GlobalConfig.EMBEDDINGS_FILE_NAME, file_name_embeddings)
    np.save(GlobalConfig.ICONS_FILE_NAME, file_names)  # Save file names for reference


def load_saved_embeddings() -> Tuple[np.ndarray, np.ndarray]:
    """
    Load precomputed embeddings and icon file names.

    :return: The embeddings and the icon file names.
    """

    file_name_embeddings = np.load(GlobalConfig.EMBEDDINGS_FILE_NAME)
    file_names = np.load(GlobalConfig.ICONS_FILE_NAME)

    return file_name_embeddings, file_names


def find_icons(keywords: List[str]) -> List[str]:
    """
    Find relevant icon file names for a list of keywords.

    :param keywords: The list of one or more keywords.
    :return: A list of the file names relevant for each keyword.
    """

    keyword_embeddings = get_embeddings(keywords)
    file_name_embeddings, file_names = load_saved_embeddings()

    # Compute pairwise cosine similarity and pick the best-matching icon per keyword
    similarities = cosine_similarity(keyword_embeddings, file_name_embeddings)
    icon_files = file_names[np.argmax(similarities, axis=-1)]

    # Convert the NumPy array to a plain list to match the annotated return type
    return icon_files.tolist()
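
# A minimal sketch of returning multiple candidates per keyword instead of a
# single best match (a hypothetical extension, not used in this module):
# np.argsort ranks the similarity scores, so the top-3 icons per keyword
# could be obtained as:
#
#     top_k = np.argsort(similarities, axis=-1)[:, ::-1][:, :3]
#     candidates = [file_names[row].tolist() for row in top_k]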


def main():
    """
    Example usage.
    """

    # Run this again if icons are to be added/removed
    save_icons_embeddings()

    keywords = [
        'deep learning',
        '',
        'recycling',
        'handshake',
        'Ferry',
        'rain drop',
        'speech bubble',
        'mental resilience',
        'turmeric',
        'Art',
        'price tag',
        'Oxygen',
        'oxygen',
        'Social Connection',
        'Accomplishment',
        'Python',
        'XML',
        'Handshake',
    ]
    icon_files = find_icons(keywords)
    print(
        f'The relevant icon files are:\n'
        f'{list(zip(keywords, icon_files))}'
    )

    # BERT tiny:
    # [('deep learning', 'deep-learning'), ('', '123'), ('recycling', 'refinery'),
    # ('handshake', 'dash-circle'), ('Ferry', 'cart'), ('rain drop', 'bucket'),
    # ('speech bubble', 'globe'), ('mental resilience', 'exclamation-triangle'),
    # ('turmeric', 'kebab'), ('Art', 'display'), ('price tag', 'bug-fill'),
    # ('Oxygen', 'radioactive')]
    # BERT mini:
    # [('deep learning', 'deep-learning'), ('', 'compass'), ('recycling', 'tools'),
    # ('handshake', 'bandaid'), ('Ferry', 'cart'), ('rain drop', 'trash'),
    # ('speech bubble', 'image'), ('mental resilience', 'recycle'), ('turmeric', 'linkedin'),
    # ('Art', 'book'), ('price tag', 'card-image'), ('Oxygen', 'radioactive')]
    # BERT small:
    # [('deep learning', 'deep-learning'), ('', 'gem'), ('recycling', 'tools'),
    # ('handshake', 'handbag'), ('Ferry', 'truck'), ('rain drop', 'bucket'),
    # ('speech bubble', 'strategy'), ('mental resilience', 'deep-learning'),
    # ('turmeric', 'flower'),
    # ('Art', 'book'), ('price tag', 'hotdog'), ('Oxygen', 'radioactive')]


if __name__ == '__main__':
    main()