shortpingoo / calculate_cosine_similarity.py
waseoke's picture
Update calculate_cosine_similarity.py
dab0bda verified
raw
history blame
3.24 kB
import os

import numpy as np
from pymongo import MongoClient
from sklearn.metrics.pairwise import cosine_similarity
# MongoDB Atlas connection.
# The connection string is read from the environment rather than hard-coded:
# a literal URI puts account credentials into source control (and the literal
# previously stored here had been mangled into an unusable URI).
MONGODB_URI = os.environ.get("MONGODB_URI")
if not MONGODB_URI:
    raise RuntimeError(
        "MONGODB_URI is not set; export the MongoDB Atlas connection string, "
        "e.g. mongodb+srv://<user>:<password>@<cluster>/test?retryWrites=true&w=majority"
    )
client = MongoClient(MONGODB_URI)
db = client["two_tower_model"]
user_embedding_collection = db["user_embeddings"]
product_embedding_collection = db["product_embeddings"]
train_dataset = db["train_dataset"]
# Similarity helper
def calculate_similarity(input_embedding, target_embeddings):
    """Score ``input_embedding`` against every row of ``target_embeddings``.

    Delegates to sklearn's ``cosine_similarity``, which returns a 2-D matrix
    (one row per input vector, one column per target). The matrix is
    flattened, so a caller that passes a single query row gets back a 1-D
    array holding one cosine score per target.
    """
    score_matrix = cosine_similarity(input_embedding, target_embeddings)
    flat_scores = score_matrix.flatten()
    return flat_scores
def find_most_similar_anchor(user_id):
    """Return the training anchor whose embedding is closest to a user's embedding.

    Looks up the user's vector in ``user_embedding_collection`` and compares it,
    by cosine similarity, with the ``anchor_embedding`` of every document in
    ``train_dataset``.

    Args:
        user_id: Value matched against the ``user_id`` field of the user
            embedding documents.

    Returns:
        A tuple ``(anchor, anchor_embedding)``: the ``anchor`` field of the
        best-matching training document and the corresponding row of the
        stacked anchor-embedding array.

    Raises:
        ValueError: If no embedding is stored for ``user_id``, or if
            ``train_dataset`` contains no documents to compare against.
    """
    user_data = user_embedding_collection.find_one({"user_id": user_id})
    if user_data is None:
        raise ValueError(f"No embedding found for user_id: {user_id}")
    # cosine_similarity expects 2-D input, so the user vector becomes one row.
    user_embedding = np.array(user_data["embedding"]).reshape(1, -1)

    # Fetch only the two fields used here; the positive/negative embeddings
    # stored in the same documents are not needed and would be wasted transfer.
    cursor = train_dataset.find({}, {"anchor": 1, "anchor_embedding": 1, "_id": 0})
    anchors, anchor_embeddings = [], []
    for entry in cursor:
        anchors.append(entry["anchor"])
        anchor_embeddings.append(entry["anchor_embedding"])
    if not anchors:
        # Fail with a clear message instead of an opaque shape error from
        # cosine_similarity / np.argmax on an empty array.
        raise ValueError("train_dataset contains no anchor embeddings")
    anchor_embeddings = np.array(anchor_embeddings)

    similarities = calculate_similarity(user_embedding, anchor_embeddings)
    most_similar_index = int(np.argmax(similarities))
    return anchors[most_similar_index], anchor_embeddings[most_similar_index]
def find_most_similar_product(anchor_embedding):
    """Return the training product whose embedding is closest to an anchor embedding.

    Every document in ``train_dataset`` contributes two candidates: its
    ``positive`` product (with ``positive_embedding``) and its ``negative``
    product (with ``negative_embedding``).

    Args:
        anchor_embedding: The anchor vector. Any array-like is accepted, e.g.
            a numpy row or a plain list read from MongoDB.

    Returns:
        A tuple ``(product, embedding)`` for the most similar candidate, where
        ``embedding`` is the matching row of the stacked candidate array.

    Raises:
        ValueError: If ``train_dataset`` contains no documents.
    """
    # NOTE(review): negative samples are included as candidates, so the result
    # may be a product stored as the *negative* of some anchor -- confirm this
    # is intended for recommendation.
    projection = {
        "positive": 1,
        "negative": 1,
        "positive_embedding": 1,
        "negative_embedding": 1,
        "_id": 0,
    }
    products, train_embeddings = [], []
    for entry in train_dataset.find({}, projection):
        products.extend([entry["positive"], entry["negative"]])
        train_embeddings.extend([entry["positive_embedding"], entry["negative_embedding"]])
    if not products:
        raise ValueError("train_dataset contains no product embeddings")
    train_embeddings = np.array(train_embeddings)

    # np.asarray lets callers pass a list as well as an ndarray; the query
    # becomes a single row because cosine_similarity expects 2-D input.
    query = np.asarray(anchor_embedding).reshape(1, -1)
    similarities = calculate_similarity(query, train_embeddings)
    most_similar_index = int(np.argmax(similarities))
    return products[most_similar_index], train_embeddings[most_similar_index]
def recommend_shop_product(similar_product_embedding):
    """Return the id of the shop product most similar to a training-product embedding.

    Compares ``similar_product_embedding`` by cosine similarity against the
    ``embedding`` of every document in ``product_embedding_collection``.

    Args:
        similar_product_embedding: Embedding of a training product. Any
            array-like is accepted, e.g. a numpy row or a plain list.

    Returns:
        The ``product_id`` field of the best-matching shop product document.

    Raises:
        ValueError: If ``product_embedding_collection`` contains no documents.
    """
    # Fetch only the fields used here rather than whole documents.
    cursor = product_embedding_collection.find({}, {"product_id": 1, "embedding": 1, "_id": 0})
    shop_product_ids, shop_product_embeddings = [], []
    for product in cursor:
        shop_product_ids.append(product["product_id"])
        shop_product_embeddings.append(product["embedding"])
    if not shop_product_ids:
        # Clear error instead of an opaque failure on an empty array.
        raise ValueError("product_embedding_collection contains no product embeddings")
    shop_product_embeddings = np.array(shop_product_embeddings)

    # cosine_similarity expects 2-D input, so the query becomes a single row.
    query = np.asarray(similar_product_embedding).reshape(1, -1)
    similarities = calculate_similarity(query, shop_product_embeddings)
    most_similar_index = int(np.argmax(similarities))
    return shop_product_ids[most_similar_index]