# Source: shortpingoo / calculate_cosine_similarity.py (author: waseoke)
# Last commit: fd7be9e "Update calculate_cosine_similarity.py" (verified), 1.41 kB
import os

import numpy as np
from pymongo import MongoClient
from sklearn.metrics.pairwise import cosine_similarity
# MongoDB connection. Credentials must never be hard-coded in source; the
# full connection string is read from the MONGODB_URI environment variable,
# e.g.
#   export MONGODB_URI="mongodb+srv://<user>:<password>@<host>/test?retryWrites=true&w=majority"
_MONGODB_URI = os.environ.get("MONGODB_URI")
if not _MONGODB_URI:
    # Fail loudly instead of letting MongoClient fall back to a default host.
    raise RuntimeError(
        "MONGODB_URI environment variable is not set; it must contain the "
        "MongoDB connection string for the two_tower_model database."
    )

client = MongoClient(_MONGODB_URI)
db = client["two_tower_model"]
# Documents holding a per-user "embedding", looked up by "user_id".
user_embedding_collection = db["user_embeddings"]
# Documents read by calculate_similarity for their "anchor",
# "anchor_embedding" and "positive" fields.
train_dataset = db["train_dataset"]
def calculate_similarity(user_id):
    """Return the positive item paired with the anchor most similar to a user.

    Fetches the user's stored embedding and compares it by cosine similarity
    with the ``anchor_embedding`` of every document in ``train_dataset``. The
    ``positive`` value of the best-matching document is printed and returned.

    Args:
        user_id: Value matched against the ``user_id`` field of
            ``user_embedding_collection``.

    Returns:
        The ``positive`` field of the training document whose anchor
        embedding has the highest cosine similarity to the user embedding
        (the first such document when several tie).

    Raises:
        ValueError: If no embedding document exists for ``user_id``, or if
            ``train_dataset`` contains no documents to compare against.
    """
    # Fetch the user's embedding and shape it as a single-row 2-D matrix,
    # which is what cosine_similarity expects.
    user_data = user_embedding_collection.find_one({"user_id": user_id})
    if not user_data:
        raise ValueError(f"No embedding found for user_id: {user_id}")
    user_embedding = np.array(user_data["embedding"]).reshape(1, -1)

    # Load only the fields used below rather than entire documents.
    train_data = list(
        train_dataset.find(
            {},
            {"anchor": 1, "anchor_embedding": 1, "positive": 1, "_id": 0},
        )
    )
    if not train_data:
        # Otherwise np.array([]) would be 1-D and cosine_similarity would
        # reject it with a much less informative error.
        raise ValueError("No anchor data found in train_dataset")
    anchor_embeddings = np.array([entry["anchor_embedding"] for entry in train_data])

    # One similarity score per anchor, in the same order as train_data.
    similarities = cosine_similarity(user_embedding, anchor_embeddings).flatten()

    # np.argmax picks the first index among equal maxima.
    most_similar_entry = train_data[int(np.argmax(similarities))]
    most_similar_positive = most_similar_entry["positive"]

    print(f"Most similar anchor for user {user_id}: {most_similar_entry['anchor']}")
    print(f"Recommended positive product: {most_similar_positive}")
    return most_similar_positive