waseoke committed on
Commit
fd7be9e
·
verified ·
1 Parent(s): 5d7babf

Update calculate_cosine_similarity.py

Browse files
Files changed (1) hide show
  1. calculate_cosine_similarity.py +32 -8
calculate_cosine_similarity.py CHANGED
@@ -1,12 +1,36 @@
1
- from sklearn.metrics.pairwise import cosine_similarity
2
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
def calculate_cosine_similarity(user_embedding, product_embeddings, product_ids, top_n=5):
    """Rank products by cosine similarity to a user embedding.

    Args:
        user_embedding: The user's embedding vector. Any array-like is
            accepted; it is flattened into a single row before comparison.
        product_embeddings: Sequence of product embedding vectors, one per
            product, in the same order as ``product_ids``.
        product_ids: Identifiers indexed in parallel with
            ``product_embeddings``.
        top_n: Number of best-scoring products to return.

    Returns:
        A list of ``(product_id, similarity)`` tuples ordered from most to
        least similar, truncated to ``top_n`` entries.
    """
    # np.asarray lets callers pass a plain list as well as an ndarray, the
    # same way product_embeddings is already coerced below.
    user_embedding = np.asarray(user_embedding).reshape(1, -1)
    product_embeddings = np.array(product_embeddings)

    # cosine_similarity returns a (1, n_products) matrix; flatten to 1-D.
    similarities = cosine_similarity(user_embedding, product_embeddings).flatten()

    # argsort is ascending, so reverse it to put the best matches first.
    top_indices = similarities.argsort()[::-1][:top_n]
    return [(product_ids[i], similarities[i]) for i in top_indices]
 
 
 
1
import os

import numpy as np
from pymongo import MongoClient
from sklearn.metrics.pairwise import cosine_similarity
4
+
5
# MongoDB connection shared by the functions in this module.
# The MONGODB_URI environment variable takes precedence so that credentials
# can be supplied at deploy time; the literal is kept only as a
# backward-compatible fallback.
# NOTE(review): the fallback embeds credentials in source control -- rotate
# them and remove the literal once MONGODB_URI is configured everywhere.
client = MongoClient(
    os.environ.get(
        "MONGODB_URI",
        "mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority",
    )
)
db = client["two_tower_model"]
# Documents here are looked up by "user_id" and read for their "embedding".
user_embedding_collection = db["user_embeddings"]
# Documents here are read for "anchor", "anchor_embedding" and "positive".
train_dataset = db["train_dataset"]
9
+
10
def calculate_similarity(user_id):
    """Return the ``positive`` item of the training anchor closest to a user.

    Fetches the stored embedding for ``user_id``, scores it by cosine
    similarity against the ``anchor_embedding`` of every document in
    ``train_dataset``, and returns the ``positive`` field of the best match.
    The selected anchor and the recommendation are also printed to stdout.

    Args:
        user_id: Value matched against the ``user_id`` field of
            ``user_embedding_collection``.

    Returns:
        The ``positive`` field of the most similar training document.

    Raises:
        ValueError: If no embedding document exists for ``user_id``, or if
            ``train_dataset`` contains no documents to compare against.
        KeyError: If a fetched document lacks one of the fields read here
            (``embedding``, ``anchor_embedding``, ``anchor``, ``positive``).
    """
    # Fetch the user's embedding.
    user_data = user_embedding_collection.find_one({"user_id": user_id})
    if not user_data:
        raise ValueError(f"No embedding found for user_id: {user_id}")

    # cosine_similarity expects 2-D input, so shape the vector as one row.
    user_embedding = np.array(user_data["embedding"]).reshape(1, -1)

    # Fetch the anchor data. The documents are materialised as a list because
    # the winning index is used to look the full entry up again below.
    train_data = list(train_dataset.find())
    if not train_data:
        # Without this guard an empty collection yields a 1-D empty array and
        # an unrelated-looking shape error from scikit-learn.
        raise ValueError("No anchor data found in train_dataset")

    anchor_embeddings = np.array(
        [entry["anchor_embedding"] for entry in train_data]
    )

    # Compute cosine similarity: (1, n_anchors) flattened to (n_anchors,).
    similarities = cosine_similarity(user_embedding, anchor_embeddings).flatten()

    # Pick the most similar anchor.
    most_similar_index = int(np.argmax(similarities))
    most_similar_entry = train_data[most_similar_index]
    most_similar_positive = most_similar_entry["positive"]

    print(f"Most similar anchor for user {user_id}: {most_similar_entry['anchor']}")
    print(f"Recommended positive product: {most_similar_positive}")
    return most_similar_positive