henryhyunwookim commited on
Commit
dc81f01
·
verified ·
1 Parent(s): 5bf544e

Update utils/utils.py

Browse files
Files changed (1) hide show
  1. utils/utils.py +176 -153
utils/utils.py CHANGED
@@ -1,154 +1,177 @@
1
- import os
2
- import logging
3
- from datetime import datetime
4
- from pathlib import Path
5
- import pickle
6
- from tqdm import tqdm
7
- from datasets import load_dataset
8
- import chromadb
9
- import matplotlib.pyplot as plt
10
-
11
-
12
- def set_directories():
13
- curr_dir = Path(os.getcwd())
14
-
15
- data_dir = curr_dir / 'data'
16
- data_pickle_path = data_dir / 'data_set.pkl'
17
-
18
- vectordb_dir = curr_dir / 'vectore_storage'
19
- chroma_dir = vectordb_dir / 'chroma'
20
-
21
- for dir in [data_dir, vectordb_dir, chroma_dir]:
22
- if not os.path.exists(dir):
23
- os.mkdir(dir)
24
-
25
- return data_pickle_path, chroma_dir
26
-
27
-
28
- def load_data(data_pickle_path, dataset="vipulmaheshwari/GTA-Image-Captioning-Dataset"):
29
- if not os.path.exists(data_pickle_path):
30
- print(f"Data set hasn't been loaded. Loading from the datasets library and save it as a pickle.")
31
- data_set = load_dataset(dataset)
32
- with open(data_pickle_path, 'wb') as outfile:
33
- pickle.dump(data_set, outfile)
34
- else:
35
- print(f"Data set already exists in the local drive. Loading it.")
36
- with open(data_pickle_path, 'rb') as infile:
37
- data_set = pickle.load(infile)
38
-
39
- return data_set
40
-
41
-
42
- def get_embeddings(data, model):
43
- # Get the id and embedding of each data/image
44
- ids = []
45
- embeddings = []
46
- for id, image in tqdm(zip(list(range(len(data))), data)):
47
- ids.append("image "+str(id))
48
-
49
- embedding = model.encode(image)
50
- embeddings.append(embedding.tolist())
51
-
52
- return ids, embeddings
53
-
54
-
55
- def get_collection(chroma_dir, model, collection_name, data):
56
- client = chromadb.PersistentClient(path=chroma_dir.__str__())
57
- collection = client.get_or_create_collection(name=collection_name)
58
-
59
- if collection.count() != len(data):
60
- print("Adding embeddings to the collection.")
61
- ids, embeddings = get_embeddings(data, model)
62
- collection.add(
63
- ids=ids,
64
- embeddings=embeddings
65
- )
66
- else:
67
- print("Embeddings are already added to the collection.")
68
-
69
- return collection
70
-
71
-
72
- def get_result(collection, data_set, query, model, n_results=2):
73
- # Query the vector store and get results
74
- results = collection.query(
75
- query_embeddings=model.encode([query]),
76
- n_results=2
77
- )
78
-
79
- # Get the id of the most relevant image
80
- img_id = int(results['ids'][0][0].split('image ')[-1])
81
-
82
- # Get the image and its caption
83
- image = data_set['train']['image'][img_id]
84
- text = data_set['train']['text'][img_id]
85
-
86
- return image, text
87
-
88
-
89
- def show_image(image, text, query):
90
- plt.ion()
91
- plt.axis("off")
92
- plt.imshow(image)
93
- plt.show()
94
- print(f"User query: {query}")
95
- print(f"Original description: {text}\n")
96
-
97
-
98
- def get_logger():
99
- log_path = "./log/"
100
- if not os.path.exists(log_path):
101
- os.mkdir(log_path)
102
-
103
- cur_date = datetime.utcnow().strftime("%Y%m%d")
104
- log_filename = f"{log_path}{cur_date}.log"
105
-
106
- logging.basicConfig(
107
- filename=log_filename,
108
- level=logging.INFO,
109
- format="%(asctime)s %(levelname)-8s %(message)s",
110
- datefmt="%Y-%m-%d %H:%M:%S")
111
-
112
- logger = logging.getLogger(__name__)
113
-
114
- return logger
115
-
116
-
117
- def initialization(logger):
118
- print("Initializing...")
119
- logger.info("Initializing...")
120
- print("-------------------------------------------------------")
121
- logger.info("-------------------------------------------------------")
122
-
123
- print("Importing functions...")
124
- logger.info("Importing functions...")
125
- # Import module, classes, and functions
126
- from sentence_transformers import SentenceTransformer
127
- from utils.utils import set_directories, load_data, get_collection, get_result, show_image
128
-
129
- print("Set directories...")
130
- logger.info("Set directories...")
131
- # Set directories
132
- data_pickle_path, chroma_dir = set_directories()
133
-
134
- print("Loading data...")
135
- logger.info("Loading data...")
136
- # Load dataset
137
- data_set = load_data(data_pickle_path)
138
-
139
- print("Loading CLIP model...")
140
- logger.info("Loading CLIP model...")
141
- # Load CLIP model
142
- model = SentenceTransformer("sentence-transformers/clip-ViT-L-14")
143
-
144
- print("Getting vector embeddings...")
145
- logger.info("Getting vector embeddings...")
146
- # Get vector embeddings
147
- collection = get_collection(chroma_dir, model, collection_name='image_vectors', data=data_set['train']['image'])
148
-
149
- print("-------------------------------------------------------")
150
- logger.info("-------------------------------------------------------")
151
- print("Initialization completed! Ready for search.")
152
- logger.info("Initialization completed! Ready for search.")
153
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  return collection, data_set, model, logger
 
1
+ import os
2
+ import logging
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ import pickle
6
+ from tqdm import tqdm
7
+ from datasets import load_dataset
8
+ import chromadb
9
+ import matplotlib.pyplot as plt
10
+ from sentence_transformers import SentenceTransformer
11
+ import google.generativeai as genai
12
+ from dotenv import load_dotenv
13
+
14
+
15
+ def set_directories():
16
+ curr_dir = Path(os.getcwd())
17
+
18
+ data_dir = curr_dir / 'data'
19
+ data_pickle_path = data_dir / 'data_set.pkl'
20
+
21
+ vectordb_dir = curr_dir / 'vector_storage'
22
+ chroma_dir = vectordb_dir / 'chroma'
23
+
24
+ for dir in [data_dir, vectordb_dir, chroma_dir]:
25
+ if not os.path.exists(dir):
26
+ os.mkdir(dir)
27
+
28
+ return data_pickle_path, chroma_dir
29
+
30
+
31
+ def load_data(data_pickle_path, dataset="vipulmaheshwari/GTA-Image-Captioning-Dataset"):
32
+ if not os.path.exists(data_pickle_path):
33
+ print(f"Data set hasn't been loaded. Loading from the datasets library and save it as a pickle.")
34
+ data_set = load_dataset(dataset)
35
+ with open(data_pickle_path, 'wb') as outfile:
36
+ pickle.dump(data_set, outfile)
37
+ else:
38
+ print(f"Data set already exists in the local drive. Loading it.")
39
+ with open(data_pickle_path, 'rb') as infile:
40
+ data_set = pickle.load(infile)
41
+
42
+ return data_set
43
+
44
+
45
+ def get_embeddings(data, model):
46
+ # Get the id and embedding of each data/image
47
+ ids = []
48
+ embeddings = []
49
+ for id, image in tqdm(zip(list(range(len(data))), data)):
50
+ ids.append("image "+str(id))
51
+
52
+ embedding = model.encode(image)
53
+ embeddings.append(embedding.tolist())
54
+
55
+ return ids, embeddings
56
+
57
+
58
+ def get_collection(chroma_dir, model, collection_name, data):
59
+ client = chromadb.PersistentClient(path=chroma_dir.__str__())
60
+ collection = client.get_or_create_collection(name=collection_name)
61
+
62
+ if collection.count() != len(data):
63
+ print("Adding embeddings to the collection.")
64
+ ids, embeddings = get_embeddings(data, model)
65
+ collection.add(
66
+ ids=ids,
67
+ embeddings=embeddings
68
+ )
69
+ else:
70
+ print("Embeddings are already added to the collection.")
71
+
72
+ return collection
73
+
74
+
75
+ def get_search_result(collection, data_set, query, model, n_results=2):
76
+ # Query the vector store and get results
77
+ results = collection.query(
78
+ query_embeddings=model.encode([query]),
79
+ n_results=2
80
+ )
81
+
82
+ # Get the id of the most relevant image
83
+ img_id = int(results['ids'][0][0].split('image ')[-1])
84
+
85
+ # Get the image and its caption
86
+ image = data_set['train']['image'][img_id]
87
+ text = data_set['train']['text'][img_id]
88
+
89
+ return image, text
90
+
91
+
92
+ def show_image(image, text, query):
93
+ plt.ion()
94
+ plt.axis("off")
95
+ plt.imshow(image)
96
+ plt.show()
97
+ print(f"User query: {query}")
98
+ print(f"Original description: {text}\n")
99
+
100
+
101
+ def get_logger():
102
+ log_path = "./log/"
103
+ if not os.path.exists(log_path):
104
+ os.mkdir(log_path)
105
+
106
+ cur_date = datetime.utcnow().strftime("%Y%m%d")
107
+ log_filename = f"{log_path}{cur_date}.log"
108
+
109
+ logging.basicConfig(
110
+ filename=log_filename,
111
+ level=logging.INFO,
112
+ format="%(asctime)s %(levelname)-8s %(message)s",
113
+ datefmt="%Y-%m-%d %H:%M:%S")
114
+
115
+ logger = logging.getLogger(__name__)
116
+
117
+ return logger
118
+
119
+
120
+ def get_image_description(image):
121
+ _ = load_dotenv()
122
+ GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
123
+ genai.configure(api_key=GOOGLE_API_KEY)
124
+
125
+ vision_model = genai.GenerativeModel(
126
+ "gemini-pro-vision",
127
+ generation_config={
128
+ "temperature": 0.0
129
+ }
130
+ )
131
+
132
+ # image = Image.open(image_path)
133
+
134
+ prompt = f"""
135
+ Describe what you explicitly see in the given image in detail.
136
+ Begin your description with "In this image," or "This image is about," to provide context.
137
+ Your response should be a hard description of the given image without any thoughts or suggestions.
138
+ """
139
+
140
+ response = vision_model.generate_content([prompt, image])
141
+ description_by_llm = response.text
142
+
143
+ return description_by_llm
144
+
145
+
146
+ def initialization(logger):
147
+ print("Initializing...")
148
+ logger.info("Initializing...")
149
+ print("-------------------------------------------------------")
150
+ logger.info("-------------------------------------------------------")
151
+
152
+ print("Set directories...")
153
+ logger.info("Set directories...")
154
+ # Set directories
155
+ data_pickle_path, chroma_dir = set_directories()
156
+
157
+ print("Loading data...")
158
+ logger.info("Loading data...")
159
+ # Load dataset
160
+ data_set = load_data(data_pickle_path)
161
+
162
+ print("Loading CLIP model...")
163
+ logger.info("Loading CLIP model...")
164
+ # Load CLIP model
165
+ model = SentenceTransformer("sentence-transformers/clip-ViT-L-14")
166
+
167
+ print("Getting vector embeddings...")
168
+ logger.info("Getting vector embeddings...")
169
+ # Get vector embeddings
170
+ collection = get_collection(chroma_dir, model, collection_name='image_vectors', data=data_set['train']['image'])
171
+
172
+ print("-------------------------------------------------------")
173
+ logger.info("-------------------------------------------------------")
174
+ print("Initialization completed! Ready for search.")
175
+ logger.info("Initialization completed! Ready for search.")
176
+
177
  return collection, data_set, model, logger