waseoke commited on
Commit
d917b06
·
verified ·
1 Parent(s): dab0bda

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +26 -65
main.py CHANGED
@@ -1,67 +1,28 @@
1
- from train_model import train_triplet_model
2
- from embed_data import embed_product_data, embed_user_data
3
- from calculate_similarity import calculate_cosine_similarity
4
- from pymongo import MongoClient
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # MongoDB 연결
7
- client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority")
8
- db = client["two_tower_model"]
9
- product_collection = db["product_tower"]
10
- user_collection = db["user_tower"]
11
- product_embedding_collection = db["product_embeddings"]
12
- user_embedding_collection = db["user_embeddings"]
13
-
14
- # 모델 학습
15
- def train_model_and_embed():
16
- product_model = None # Define or load your model
17
- anchor_data, positive_data, negative_data = load_training_data()
18
- trained_model = train_triplet_model(product_model, anchor_data, positive_data, negative_data)
19
-
20
- return trained_model
21
-
22
- # 데이터 임베딩 및 저장
23
- def embed_and_save():
24
- all_products = list(product_collection.find())
25
- all_users = list(user_collection.find())
26
-
27
- for product_data in all_products:
28
- embedding = embed_product_data(product_data)
29
- product_embedding_collection.update_one(
30
- {"product_id": product_data["product_id"]},
31
- {"$set": {"embedding": embedding.tolist()}},
32
- upsert=True
33
- )
34
-
35
- for user_data in all_users:
36
- embedding = embed_user_data(user_data)
37
- user_embedding_collection.update_one(
38
- {"user_id": user_data["user_id"]},
39
- {"$set": {"embedding": embedding.tolist()}},
40
- upsert=True
41
- )
42
-
43
- # 추천 실행
44
- def recommend(user_id, top_n=5):
45
- user_embedding_data = user_embedding_collection.find_one({"user_id": user_id})
46
- if not user_embedding_data:
47
- print(f"No embedding found for user_id: {user_id}")
48
- return []
49
-
50
- user_embedding = np.array(user_embedding_data["embedding"])
51
- all_products = list(product_embedding_collection.find())
52
- product_ids = [prod["product_id"] for prod in all_products]
53
- product_embeddings = [prod["embedding"] for prod in all_products]
54
-
55
- recommendations = calculate_cosine_similarity(user_embedding, product_embeddings, product_ids, top_n)
56
- print(f"Recommendations for user {user_id}: {recommendations}")
57
- return recommendations
58
-
59
- # 실행
60
  if __name__ == "__main__":
61
- # Train and embed data
62
- train_model_and_embed()
63
- embed_and_save()
64
-
65
- # Recommend products for a user
66
- user_id = "정우석"
67
- recommend(user_id, top_n=3)
 
1
+ import torch
2
+ from calculate_cosine_similarity import (
3
+ find_most_similar_anchor,
4
+ find_most_similar_product,
5
+ recommend_shop_product,
6
+ )
7
+
8
+ def main():
9
+ # 사용자 ID 입력
10
+ user_id = "user_123" # 사용자 ID 예시
11
+
12
+ # Step 1: 사용자와 가장 유사한 anchor 찾기
13
+ print(f"Finding the most similar anchor for user {user_id}...")
14
+ most_similar_anchor, anchor_embedding = find_most_similar_anchor(user_id)
15
+ print(f"Most similar anchor: {most_similar_anchor}")
16
+
17
+ # Step 2: anchor와 가장 유사한 상품 찾기
18
+ print("Finding the most similar product to the anchor...")
19
+ most_similar_product, similar_product_embedding = find_most_similar_product(anchor_embedding)
20
+ print(f"Most similar product to anchor: {most_similar_product}")
21
+
22
+ # Step 3: 쇼핑몰 상품 추천
23
+ print("Recommending the best shop product...")
24
+ recommended_product_id = recommend_shop_product(similar_product_embedding)
25
+ print(f"Recommended shop product ID: {recommended_product_id}")
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  if __name__ == "__main__":
28
+ main()