Spaces:
Sleeping
Sleeping
File size: 2,452 Bytes
b4a0526 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import numpy as np
from pymongo import MongoClient

from train_model import train_triplet_model
from embed_data import embed_product_data, embed_user_data
from calculate_similarity import calculate_cosine_similarity
# MongoDB connection
# NOTE(review): the connection string carries credentials directly in source;
# consider loading it from configuration or an environment variable instead.
MONGO_URI = "mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority"
client = MongoClient(MONGO_URI)
db = client["two_tower_model"]

# Collections read by embed_and_save() as the raw product / user inputs.
product_collection = db["product_tower"]
user_collection = db["user_tower"]

# Collections where embed_and_save() writes embeddings and recommend() reads them.
product_embedding_collection = db["product_embeddings"]
user_embedding_collection = db["user_embeddings"]
# Model training
def train_model_and_embed():
    """Train the triplet model on the loaded training data and return it.

    Returns:
        The value returned by ``train_triplet_model``.
    """
    # Placeholder: no model is constructed or loaded yet, so ``None`` is what
    # gets handed to the trainer.
    # NOTE(review): confirm train_triplet_model accepts None as its model.
    product_model = None

    # NOTE(review): load_training_data is neither defined nor imported in this
    # file as visible here; calling this function would raise NameError unless
    # it is provided elsewhere.
    triplets = load_training_data()
    anchor_data, positive_data, negative_data = triplets

    return train_triplet_model(
        product_model,
        anchor_data,
        positive_data,
        negative_data,
    )
# Embed the data and store the results
def embed_and_save():
    """Embed every product and user document and upsert the resulting vectors.

    Both source collections are read in full before any embedding is computed.
    Each embedding is stored under the ``embedding`` field of the document
    whose ``product_id`` / ``user_id`` matches the source document, creating
    that document when it does not exist yet.
    """
    product_docs = list(product_collection.find())
    user_docs = list(user_collection.find())

    # (documents, id field, embedding function, destination collection)
    jobs = (
        (product_docs, "product_id", embed_product_data, product_embedding_collection),
        (user_docs, "user_id", embed_user_data, user_embedding_collection),
    )

    for docs, id_field, embed, destination in jobs:
        for doc in docs:
            vector = embed(doc)
            # The embedding is converted with ``tolist()`` before being stored.
            # NOTE(review): presumably a NumPy array or tensor — confirm
            # against embed_data.
            destination.update_one(
                {id_field: doc[id_field]},
                {"$set": {"embedding": vector.tolist()}},
                upsert=True,
            )
# Run recommendation
def recommend(user_id, top_n=5):
    """Return the top product recommendations for a user with a stored embedding.

    Looks up the user's embedding in ``user_embedding_collection``, loads every
    document from ``product_embedding_collection`` and delegates the ranking to
    ``calculate_cosine_similarity``. The outcome is also printed.

    Args:
        user_id: Value matched against the ``user_id`` field of the stored
            user embeddings.
        top_n: Number of recommendations requested from the similarity helper.

    Returns:
        The value returned by ``calculate_cosine_similarity``, or an empty list
        when the user has no stored embedding or no product embeddings exist.
    """
    user_embedding_data = user_embedding_collection.find_one({"user_id": user_id})
    if not user_embedding_data:
        print(f"No embedding found for user_id: {user_id}")
        return []

    all_products = list(product_embedding_collection.find())
    if not all_products:
        # Nothing to rank against: bail out instead of handing empty inputs to
        # the similarity helper, mirroring the missing-user branch above.
        print("No product embeddings found")
        return []

    user_embedding = np.array(user_embedding_data["embedding"])
    product_ids = [prod["product_id"] for prod in all_products]
    product_embeddings = [prod["embedding"] for prod in all_products]

    recommendations = calculate_cosine_similarity(
        user_embedding, product_embeddings, product_ids, top_n
    )
    print(f"Recommendations for user {user_id}: {recommendations}")
    return recommendations
# Script entry point
if __name__ == "__main__":
    # Train the model, then embed and store all products and users.
    # NOTE(review): the trained model returned here is discarded and is not
    # passed to embed_and_save(); confirm that is intended.
    train_model_and_embed()
    embed_and_save()

    # Print the top three product recommendations for one sample user.
    recommend("정우석", top_n=3)
|