# Scraped page header (not part of the program): "Spaces: Sleeping"
import os

import numpy as np
from pymongo import MongoClient

from calculate_similarity import calculate_cosine_similarity
from embed_data import embed_product_data, embed_user_data
from train_model import train_triplet_model
# MongoDB connection.
# The URI may be supplied through the MONGODB_URI environment variable so that
# credentials do not have to live in source control. The literal below is only
# the fallback and preserves the previous behavior when the variable is unset
# or empty.
MONGODB_URI = os.environ.get("MONGODB_URI") or (
    "mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority"
)
client = MongoClient(MONGODB_URI)
db = client["two_tower_model"]

# Source documents consumed by the embedding step.
product_collection = db["product_tower"]
user_collection = db["user_tower"]

# Destination collections for the computed embeddings.
product_embedding_collection = db["product_embeddings"]
user_embedding_collection = db["user_embeddings"]
# Model training
def train_model_and_embed():
    """Train the triplet model and return the result.

    Loads the anchor/positive/negative training data with
    ``load_training_data`` and passes it, together with the base model, to
    ``train_triplet_model``.

    NOTE(review): despite the name, no embedding is computed here.
    NOTE(review): ``load_training_data`` is neither defined nor imported in
    the visible part of this file -- confirm where it is supposed to come from.

    Returns:
        Whatever ``train_triplet_model`` returns (presumably the trained
        model -- verify against ``train_model``).
    """
    base_model = None  # Define or load your model
    anchors, positives, negatives = load_training_data()
    return train_triplet_model(base_model, anchors, positives, negatives)
# Embed the data and store the results
def embed_and_save():
    """Embed every product and user document and upsert the embeddings.

    Both source collections are read in full before any write happens. Each
    document is then embedded with the matching ``embed_*`` function and the
    result is written to the corresponding embeddings collection, keyed by
    ``product_id`` / ``user_id`` (created if missing, overwritten otherwise).
    """
    products = list(product_collection.find())
    users = list(user_collection.find())

    # (documents, id field, embedding function, destination collection)
    jobs = (
        (products, "product_id", embed_product_data, product_embedding_collection),
        (users, "user_id", embed_user_data, user_embedding_collection),
    )
    for documents, id_field, embed, destination in jobs:
        for document in documents:
            vector = embed(document)
            destination.update_one(
                {id_field: document[id_field]},
                # ``tolist()`` converts the embedding into plain Python lists
                # before it is stored.
                {"$set": {"embedding": vector.tolist()}},
                upsert=True,
            )
# Run recommendation
def recommend(user_id, top_n=5):
    """Recommend products for a user from the stored embeddings.

    Looks up the user's stored embedding, gathers every stored product
    embedding, and delegates the ranking to ``calculate_cosine_similarity``.

    Args:
        user_id: Value of the ``user_id`` field to look up in the user
            embeddings collection.
        top_n: Number of recommendations requested from the similarity helper.

    Returns:
        The value returned by ``calculate_cosine_similarity`` (its exact shape
        is defined in ``calculate_similarity`` -- not visible here), or an
        empty list when the user has no stored embedding or there are no
        product embeddings to compare against.
    """
    user_embedding_data = user_embedding_collection.find_one({"user_id": user_id})
    if not user_embedding_data:
        print(f"No embedding found for user_id: {user_id}")
        return []
    user_embedding = np.array(user_embedding_data["embedding"])

    # Collect ids and vectors in one pass so the two lists stay aligned.
    product_ids = []
    product_embeddings = []
    for prod in product_embedding_collection.find():
        product_ids.append(prod["product_id"])
        product_embeddings.append(prod["embedding"])

    # Nothing to rank: return early instead of handing empty inputs to the
    # similarity helper, mirroring the "no user embedding" case above.
    if not product_ids:
        print(f"No product embeddings available for user_id: {user_id}")
        return []

    recommendations = calculate_cosine_similarity(
        user_embedding, product_embeddings, product_ids, top_n
    )
    print(f"Recommendations for user {user_id}: {recommendations}")
    return recommendations
# Script entry point
if __name__ == "__main__":
    # Train the model, then (re)compute and store all embeddings.
    # NOTE(review): the value returned by train_model_and_embed() is discarded
    # here -- confirm the embed_* functions pick up the trained model some
    # other way.
    train_model_and_embed()
    embed_and_save()

    # Recommend products for a sample user.
    user_id = "정우석"
    recommend(user_id=user_id, top_n=3)