waseoke committed on
Commit
dab0bda
·
verified ·
1 Parent(s): fd7be9e

Update calculate_cosine_similarity.py

Browse files
Files changed (1) hide show
  1. calculate_cosine_similarity.py +57 -10
calculate_cosine_similarity.py CHANGED
@@ -2,12 +2,25 @@ from pymongo import MongoClient
2
  import numpy as np
3
  from sklearn.metrics.pairwise import cosine_similarity
4
 
 
5
  client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority")
6
  db = client["two_tower_model"]
7
  user_embedding_collection = db["user_embeddings"]
 
8
  train_dataset = db["train_dataset"]
9
 
10
- def calculate_similarity(user_id):
 
 
 
 
 
 
 
 
 
 
 
11
  # 사용자 임베딩 가져오기
12
  user_data = user_embedding_collection.find_one({"user_id": user_id})
13
  if not user_data:
@@ -16,21 +29,55 @@ def calculate_similarity(user_id):
16
  user_embedding = np.array(user_data["embedding"]).reshape(1, -1)
17
 
18
  # Anchor 데이터 가져오기
19
- anchor_embeddings = []
20
  train_data = list(train_dataset.find())
21
  for entry in train_data:
 
22
  anchor_embeddings.append(entry["anchor_embedding"])
23
-
24
  anchor_embeddings = np.array(anchor_embeddings)
25
 
26
  # Cosine Similarity 계산
27
- similarities = cosine_similarity(user_embedding, anchor_embeddings).flatten()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # 가장 유사한 anchor 선택
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  most_similar_index = np.argmax(similarities)
31
- most_similar_entry = train_data[most_similar_index]
32
- most_similar_positive = most_similar_entry["positive"]
33
 
34
- print(f"Most similar anchor for user {user_id}: {most_similar_entry['anchor']}")
35
- print(f"Recommended positive product: {most_similar_positive}")
36
- return most_similar_positive
 
2
  import numpy as np
3
  from sklearn.metrics.pairwise import cosine_similarity
4
 
5
# Connect to MongoDB Atlas.
# NOTE(review): the fallback URI below embeds credentials in source control.
# Prefer setting the MONGODB_URI environment variable and rotate the exposed
# password; the literal is kept only so existing deployments keep working.
import os

MONGODB_URI = os.environ.get(
    "MONGODB_URI",
    "mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority",
)
client = MongoClient(MONGODB_URI)
db = client["two_tower_model"]
user_embedding_collection = db["user_embeddings"]
product_embedding_collection = db["product_embeddings"]
train_dataset = db["train_dataset"]
11
 
12
+ # 유사도 계산 함수
13
def calculate_similarity(input_embedding, target_embeddings):
    """Compute cosine similarity between an input embedding and target embeddings.

    Args:
        input_embedding: Query embedding(s). A 1-D vector is treated as a
            single row; 2-D input is passed through unchanged.
        target_embeddings: Embeddings to compare against, one per row. A 1-D
            vector is treated as a single target.

    Returns:
        The pairwise cosine similarities, flattened to a 1-D array.
    """
    # cosine_similarity only accepts 2-D input, so promote bare vectors to a
    # single row instead of forcing every caller to reshape first.
    if np.ndim(input_embedding) == 1:
        input_embedding = np.reshape(input_embedding, (1, -1))
    if np.ndim(target_embeddings) == 1:
        target_embeddings = np.reshape(target_embeddings, (1, -1))

    return cosine_similarity(input_embedding, target_embeddings).flatten()
19
+
20
+ def find_most_similar_anchor(user_id):
21
+ """
22
+ 사용자 임베딩을 기준으로 가장 유사한 anchor 상품을 반환.
23
+ """
24
  # 사용자 임베딩 가져오기
25
  user_data = user_embedding_collection.find_one({"user_id": user_id})
26
  if not user_data:
 
29
  user_embedding = np.array(user_data["embedding"]).reshape(1, -1)
30
 
31
  # Anchor 데이터 가져오기
32
+ anchors, anchor_embeddings = [], []
33
  train_data = list(train_dataset.find())
34
  for entry in train_data:
35
+ anchors.append(entry["anchor"])
36
  anchor_embeddings.append(entry["anchor_embedding"])
37
+
38
  anchor_embeddings = np.array(anchor_embeddings)
39
 
40
  # Cosine Similarity 계산
41
+ similarities = calculate_similarity(user_embedding, anchor_embeddings)
42
+ most_similar_index = np.argmax(similarities)
43
+
44
+ return anchors[most_similar_index], anchor_embeddings[most_similar_index]
45
+
46
def find_most_similar_product(anchor_embedding):
    """Return the training product whose embedding is most similar to an anchor.

    Each document in ``train_dataset`` contributes two candidates: its
    ``positive`` and its ``negative`` product, with their stored embeddings.

    Args:
        anchor_embedding: Anchor embedding vector (ndarray or array-like).

    Returns:
        A tuple ``(product, embedding)`` for the candidate with the highest
        cosine similarity to ``anchor_embedding``.

    Raises:
        ValueError: If ``train_dataset`` contains no documents.
    """
    products, train_embeddings = [], []
    # Iterate the cursor directly; there is no need to materialise every
    # document into a list before walking it once.
    for entry in train_dataset.find():
        products.extend([entry["positive"], entry["negative"]])
        train_embeddings.extend([entry["positive_embedding"], entry["negative_embedding"]])

    # Fail with a clear message instead of an opaque shape error further down.
    if not products:
        raise ValueError("train_dataset is empty; no products to compare against")

    train_embeddings = np.array(train_embeddings)

    # np.asarray lets callers pass a plain list as well as an ndarray.
    query = np.asarray(anchor_embedding).reshape(1, -1)
    similarities = calculate_similarity(query, train_embeddings)
    most_similar_index = np.argmax(similarities)

    return products[most_similar_index], train_embeddings[most_similar_index]
64
+
65
def recommend_shop_product(similar_product_embedding):
    """Recommend the shop product closest to a trained product embedding.

    Compares ``similar_product_embedding`` against every document in
    ``product_embedding_collection`` using cosine similarity.

    Args:
        similar_product_embedding: Embedding vector of the trained product
            (ndarray or array-like).

    Returns:
        The ``product_id`` of the shop product with the highest similarity.

    Raises:
        ValueError: If ``product_embedding_collection`` contains no documents.
    """
    shop_product_ids, shop_product_embeddings = [], []
    # Iterate the cursor directly rather than building a throwaway list.
    for product in product_embedding_collection.find():
        shop_product_ids.append(product["product_id"])
        shop_product_embeddings.append(product["embedding"])

    # Fail with a clear message instead of an opaque shape error further down.
    if not shop_product_ids:
        raise ValueError("product_embeddings collection is empty; nothing to recommend")

    shop_product_embeddings = np.array(shop_product_embeddings)

    # np.asarray lets callers pass a plain list as well as an ndarray.
    query = np.asarray(similar_product_embedding).reshape(1, -1)
    similarities = calculate_similarity(query, shop_product_embeddings)
    most_similar_index = np.argmax(similarities)

    return shop_product_ids[most_similar_index]
83
+