Update calculate_cosine_similarity.py
calculate_cosine_similarity.py  CHANGED  (+57 -10)
@@ -2,12 +2,25 @@ from pymongo import MongoClient
 import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity
 
+# Connect to MongoDB Atlas
 client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority")
 db = client["two_tower_model"]
 user_embedding_collection = db["user_embeddings"]
+product_embedding_collection = db["product_embeddings"]
 train_dataset = db["train_dataset"]
 
-def calculate_similarity(user_id):
+# Similarity calculation function
+def calculate_similarity(input_embedding, target_embeddings):
+    """
+    Compute the cosine similarity between an input embedding and a set of target embeddings.
+    """
+    similarities = cosine_similarity(input_embedding, target_embeddings).flatten()
+    return similarities
+
+def find_most_similar_anchor(user_id):
+    """
+    Return the anchor product most similar to the user embedding.
+    """
     # Fetch the user embedding
     user_data = user_embedding_collection.find_one({"user_id": user_id})
     if not user_data:
@@ -16,21 +29,55 @@ def calculate_similarity(user_id):
     user_embedding = np.array(user_data["embedding"]).reshape(1, -1)
 
     # Fetch the anchor data
-    anchor_embeddings = []
+    anchors, anchor_embeddings = [], []
     train_data = list(train_dataset.find())
     for entry in train_data:
+        anchors.append(entry["anchor"])
         anchor_embeddings.append(entry["anchor_embedding"])
-
+
     anchor_embeddings = np.array(anchor_embeddings)
 
     # Compute cosine similarity
-    similarities = cosine_similarity(user_embedding, anchor_embeddings)
+    similarities = calculate_similarity(user_embedding, anchor_embeddings)
+    most_similar_index = np.argmax(similarities)
+
+    return anchors[most_similar_index], anchor_embeddings[most_similar_index]
+
+def find_most_similar_product(anchor_embedding):
+    """
+    Return the trained product most similar to the anchor embedding.
+    """
+    # Compare against the positive/negative embeddings from the train data
+    train_data = list(train_dataset.find())
+    train_embeddings, products = [], []
+    for entry in train_data:
+        products.extend([entry["positive"], entry["negative"]])
+        train_embeddings.extend([entry["positive_embedding"], entry["negative_embedding"]])
+
+    train_embeddings = np.array(train_embeddings)
 
-    #
+    # Compute cosine similarity
+    similarities = calculate_similarity(anchor_embedding.reshape(1, -1), train_embeddings)
+    most_similar_index = np.argmax(similarities)
+
+    return products[most_similar_index], train_embeddings[most_similar_index]
+
+def recommend_shop_product(similar_product_embedding):
+    """
+    Recommend a shop product by comparing the matched trained product embedding against the shop product embeddings.
+    """
+    # Fetch the shop product embedding data
+    all_products = list(product_embedding_collection.find())
+    shop_product_embeddings, shop_product_ids = [], []
+    for product in all_products:
+        shop_product_ids.append(product["product_id"])
+        shop_product_embeddings.append(product["embedding"])
+
+    shop_product_embeddings = np.array(shop_product_embeddings)
+
+    # Compute cosine similarity
+    similarities = calculate_similarity(similar_product_embedding.reshape(1, -1), shop_product_embeddings)
     most_similar_index = np.argmax(similarities)
-    most_similar_entry = train_data[most_similar_index]
-    most_similar_positive = most_similar_entry["positive"]
 
-
-
-    return most_similar_positive
+    return shop_product_ids[most_similar_index]
+
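
For orientation, here is a minimal usage sketch (not part of the commit) showing how the three lookup functions in the new version could be chained into a recommendation flow. It assumes the user_embeddings, train_dataset, and product_embeddings collections are already populated, and "user_123" is a hypothetical placeholder user ID.

# Illustrative only; "user_123" is a hypothetical user ID, not taken from the diff.
if __name__ == "__main__":
    user_id = "user_123"

    # 1. Anchor product closest to the user's embedding.
    anchor, anchor_embedding = find_most_similar_anchor(user_id)

    # 2. Trained positive/negative product closest to that anchor.
    product, product_embedding = find_most_similar_product(anchor_embedding)

    # 3. Shop product closest to the matched trained product.
    recommended_product_id = recommend_shop_product(product_embedding)

    print("Most similar anchor:", anchor)
    print("Most similar trained product:", product)
    print("Recommended shop product ID:", recommended_product_id)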