waseoke commited on
Commit
b4a0526
·
verified ·
1 Parent(s): 0182b00

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +67 -0
main.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from train_model import train_triplet_model
2
+ from embed_data import embed_product_data, embed_user_data
3
+ from calculate_similarity import calculate_cosine_similarity
4
+ from pymongo import MongoClient
5
+
6
+ # MongoDB 연결
7
+ client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority")
8
+ db = client["two_tower_model"]
9
+ product_collection = db["product_tower"]
10
+ user_collection = db["user_tower"]
11
+ product_embedding_collection = db["product_embeddings"]
12
+ user_embedding_collection = db["user_embeddings"]
13
+
14
+ # 모델 학습
15
+ def train_model_and_embed():
16
+ product_model = None # Define or load your model
17
+ anchor_data, positive_data, negative_data = load_training_data()
18
+ trained_model = train_triplet_model(product_model, anchor_data, positive_data, negative_data)
19
+
20
+ return trained_model
21
+
22
+ # 데이터 임베딩 및 저장
23
+ def embed_and_save():
24
+ all_products = list(product_collection.find())
25
+ all_users = list(user_collection.find())
26
+
27
+ for product_data in all_products:
28
+ embedding = embed_product_data(product_data)
29
+ product_embedding_collection.update_one(
30
+ {"product_id": product_data["product_id"]},
31
+ {"$set": {"embedding": embedding.tolist()}},
32
+ upsert=True
33
+ )
34
+
35
+ for user_data in all_users:
36
+ embedding = embed_user_data(user_data)
37
+ user_embedding_collection.update_one(
38
+ {"user_id": user_data["user_id"]},
39
+ {"$set": {"embedding": embedding.tolist()}},
40
+ upsert=True
41
+ )
42
+
43
+ # 추천 실행
44
+ def recommend(user_id, top_n=5):
45
+ user_embedding_data = user_embedding_collection.find_one({"user_id": user_id})
46
+ if not user_embedding_data:
47
+ print(f"No embedding found for user_id: {user_id}")
48
+ return []
49
+
50
+ user_embedding = np.array(user_embedding_data["embedding"])
51
+ all_products = list(product_embedding_collection.find())
52
+ product_ids = [prod["product_id"] for prod in all_products]
53
+ product_embeddings = [prod["embedding"] for prod in all_products]
54
+
55
+ recommendations = calculate_cosine_similarity(user_embedding, product_embeddings, product_ids, top_n)
56
+ print(f"Recommendations for user {user_id}: {recommendations}")
57
+ return recommendations
58
+
59
+ # 실행
60
+ if __name__ == "__main__":
61
+ # Train and embed data
62
+ train_model_and_embed()
63
+ embed_and_save()
64
+
65
+ # Recommend products for a user
66
+ user_id = "정우석"
67
+ recommend(user_id, top_n=3)