import os

import numpy as np
from pymongo import MongoClient

from calculate_similarity import calculate_cosine_similarity
from embed_data import embed_product_data, embed_user_data
from train_model import train_triplet_model

# MongoDB connection.
# SECURITY(review): the fallback URI below embeds credentials in source.
# Set the MONGODB_URI environment variable to override it, then rotate the
# password and remove the literal once every deployment supplies the variable.
_DEFAULT_MONGODB_URI = "mongodb+srv://waseoke:rookies3@cluster0.ps7gq.mongodb.net/test?retryWrites=true&w=majority"
client = MongoClient(os.environ.get("MONGODB_URI", _DEFAULT_MONGODB_URI))
db = client["two_tower_model"]
product_collection = db["product_tower"]
user_collection = db["user_tower"]
product_embedding_collection = db["product_embeddings"]
user_embedding_collection = db["user_embeddings"]


# Model training
def train_model_and_embed():
    """Train the triplet model and return the result of ``train_triplet_model``.

    NOTE(review): ``load_training_data`` is neither defined nor imported in
    this module as reviewed, so this function raises ``NameError`` until that
    helper is provided. ``product_model`` is also still a ``None`` placeholder
    that must be replaced with a real model before training can work.
    """
    product_model = None  # Define or load your model
    anchor_data, positive_data, negative_data = load_training_data()
    return train_triplet_model(
        product_model, anchor_data, positive_data, negative_data
    )


# Embed data and store it
def embed_and_save():
    """Embed every product and user document and upsert the embeddings.

    Reads all documents from ``product_tower`` and ``user_tower``, embeds each
    with ``embed_product_data`` / ``embed_user_data``, and upserts the result
    into ``product_embeddings`` / ``user_embeddings`` keyed by ``product_id``
    or ``user_id``.

    Raises:
        KeyError: If a document lacks its ``product_id`` / ``user_id`` field.
    """
    # Iterate the cursors directly; there is no need to hold a whole
    # collection in memory just to loop over it once.
    for product_data in product_collection.find():
        embedding = embed_product_data(product_data)
        product_embedding_collection.update_one(
            {"product_id": product_data["product_id"]},
            # .tolist() makes the embedding storable as a plain list
            # (presumably a NumPy array or tensor -- confirm in embed_data).
            {"$set": {"embedding": embedding.tolist()}},
            upsert=True,
        )

    for user_data in user_collection.find():
        embedding = embed_user_data(user_data)
        user_embedding_collection.update_one(
            {"user_id": user_data["user_id"]},
            {"$set": {"embedding": embedding.tolist()}},
            upsert=True,
        )


# Run recommendation
def recommend(user_id, top_n=5):
    """Print and return product recommendations for ``user_id``.

    Args:
        user_id: Value matched against the ``user_id`` field of the
            ``user_embeddings`` collection.
        top_n: Passed through to ``calculate_cosine_similarity``
            (presumably the number of products to return -- confirm there).

    Returns:
        An empty list when no embedding is stored for the user; otherwise
        whatever ``calculate_cosine_similarity`` returns.
    """
    user_embedding_data = user_embedding_collection.find_one({"user_id": user_id})
    if not user_embedding_data:
        print(f"No embedding found for user_id: {user_id}")
        return []

    user_embedding = np.array(user_embedding_data["embedding"])

    # Materialised because the documents are traversed twice below.
    all_products = list(product_embedding_collection.find())
    product_ids = [prod["product_id"] for prod in all_products]
    product_embeddings = [prod["embedding"] for prod in all_products]

    recommendations = calculate_cosine_similarity(
        user_embedding, product_embeddings, product_ids, top_n
    )
    print(f"Recommendations for user {user_id}: {recommendations}")
    return recommendations


# Entry point
if __name__ == "__main__":
    # Train and embed data
    train_model_and_embed()
    embed_and_save()

    # Recommend products for a user
    user_id = "정우석"
    recommend(user_id, top_n=3)