Spaces:
Sleeping
Sleeping
Create main.py
Browse files
main.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os

import numpy as np
from pymongo import MongoClient

from train_model import train_triplet_model
from embed_data import embed_product_data, embed_user_data
from calculate_similarity import calculate_cosine_similarity
|
5 |
+
|
6 |
+
# MongoDB connection
|
7 |
+
# SECURITY NOTE(review): the fallback URI below embeds credentials in source
# control. Supply the connection string through the MONGODB_URI environment
# variable instead and rotate the committed password.
# NOTE(review): the fallback literal looks like it was mangled by an e-mail
# obfuscation filter ("[email protected]") -- confirm the real host before use.
MONGODB_URI = os.environ.get(
    "MONGODB_URI",
    "mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority",
)

# Shared client and collection handles used by the functions below.
client = MongoClient(MONGODB_URI)
db = client["two_tower_model"]
product_collection = db["product_tower"]
user_collection = db["user_tower"]
product_embedding_collection = db["product_embeddings"]
user_embedding_collection = db["user_embeddings"]
|
13 |
+
|
14 |
+
# Model training
|
15 |
+
def train_model_and_embed():
    """Train the triplet model and return the result.

    Despite its name, this function performs no embedding step; it only
    loads the training triplets and hands them to ``train_triplet_model``.

    Returns:
        The value returned by ``train_triplet_model``.
    """
    # Placeholder: no base model is defined or loaded yet, so None is passed.
    base_model = None

    # NOTE(review): load_training_data is neither defined nor imported in the
    # visible file -- confirm where it is supposed to come from.
    triplets = load_training_data()
    anchor_data, positive_data, negative_data = triplets

    return train_triplet_model(base_model, anchor_data, positive_data, negative_data)
|
21 |
+
|
22 |
+
# Embed the data and store the embeddings
|
23 |
+
def embed_and_save():
    """Embed every product and user document and upsert the vectors.

    Both source collections are read in full before any write happens.
    Products are processed first, then users. Each embedding is stored as a
    plain list under the ``embedding`` field, keyed by ``product_id`` or
    ``user_id`` respectively.
    """
    product_docs = list(product_collection.find())
    user_docs = list(user_collection.find())

    # (source documents, id field, embedding function, target collection)
    jobs = (
        (product_docs, "product_id", embed_product_data, product_embedding_collection),
        (user_docs, "user_id", embed_user_data, user_embedding_collection),
    )

    for documents, id_field, embed, target in jobs:
        for document in documents:
            vector = embed(document)
            # upsert=True creates the record when no document matches the id.
            target.update_one(
                {id_field: document[id_field]},
                {"$set": {"embedding": vector.tolist()}},
                upsert=True,
            )
|
42 |
+
|
43 |
+
# Run recommendations
|
44 |
+
def recommend(user_id, top_n=5):
    """Return the top-N product recommendations for a user.

    Looks up the stored embedding for ``user_id``, loads every stored product
    embedding, and delegates the ranking to ``calculate_cosine_similarity``.

    Args:
        user_id: Value matched against the ``user_id`` field of the
            user-embedding collection.
        top_n: Number of results requested from the similarity helper.

    Returns:
        The value returned by ``calculate_cosine_similarity``, or an empty
        list when the user has no stored embedding or when no product
        embeddings are stored.
    """
    user_embedding_data = user_embedding_collection.find_one({"user_id": user_id})
    if not user_embedding_data:
        print(f"No embedding found for user_id: {user_id}")
        return []

    all_products = list(product_embedding_collection.find())
    if not all_products:
        # Nothing to rank; avoid handing empty inputs to the similarity helper.
        print("No product embeddings found; nothing to recommend.")
        return []

    user_embedding = np.array(user_embedding_data["embedding"])
    product_ids = [prod["product_id"] for prod in all_products]
    product_embeddings = [prod["embedding"] for prod in all_products]

    recommendations = calculate_cosine_similarity(
        user_embedding, product_embeddings, product_ids, top_n
    )
    print(f"Recommendations for user {user_id}: {recommendations}")
    return recommendations
|
58 |
+
|
59 |
+
# Entry point
|
60 |
+
if __name__ == "__main__":
    # Train first, then refresh the stored embeddings.
    # The trained model returned here is not used afterwards.
    train_model_and_embed()
    embed_and_save()

    # Recommend products for a sample user.
    user_id = "정우석"
    recommend(user_id, 3)
|