Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -48,12 +48,10 @@ def embed_product_data(product_data):
|
|
48 |
color_embedding = color_embedding.view(1, -1) # 2D로 변환
|
49 |
|
50 |
# 최종 임베딩 벡터 결합
|
51 |
-
|
52 |
-
|
53 |
-
)
|
54 |
|
55 |
print(f"Generated product_embedding shape: {product_embedding.shape}") # Debugging
|
56 |
-
|
57 |
return product_embedding.detach().numpy()
|
58 |
|
59 |
# 사용자 타워: 데이터 임베딩
|
@@ -76,13 +74,16 @@ def embed_user_data(user_data):
|
|
76 |
scaled_height = (height - min_height) * 99 // (max_height - min_height)
|
77 |
scaled_weight = (weight - min_weight) * 99 // (max_weight - min_weight)
|
78 |
|
79 |
-
age_embedding = embedding_layer(torch.tensor([user_data['age']]))
|
80 |
-
gender_embedding = embedding_layer(torch.tensor([gender_id]))
|
81 |
-
height_embedding = embedding_layer(torch.tensor([scaled_height]))
|
82 |
-
weight_embedding = embedding_layer(torch.tensor([scaled_weight]))
|
83 |
|
84 |
# 최종 임베딩 벡터 결합
|
85 |
-
|
|
|
|
|
|
|
86 |
return user_embedding.detach().numpy()
|
87 |
|
88 |
# MongoDB Atlas에서 데이터 가져오기
|
@@ -127,12 +128,7 @@ def recommend_products_for_user(user_id, top_n=1):
|
|
127 |
print(f"User ID {user_id} embedding not found.")
|
128 |
return []
|
129 |
|
130 |
-
|
131 |
-
user_embedding = np.array(user_embedding_data["embedding"])
|
132 |
-
|
133 |
-
# 차원 확인 및 조정
|
134 |
-
if user_embedding.ndim == 1: # 1D 배열인 경우 2D로 변환
|
135 |
-
user_embedding = user_embedding.reshape(1, -1)
|
136 |
|
137 |
# 모든 상품 임베딩 가져오기
|
138 |
all_product_embeddings = list(product_embedding_collection.find())
|
|
|
48 |
color_embedding = color_embedding.view(1, -1) # 2D로 변환
|
49 |
|
50 |
# 최종 임베딩 벡터 결합
|
51 |
+
combined_embedding = torch.cat((text_embedding, category_embedding, color_embedding), dim=1)
|
52 |
+
product_embedding = torch.nn.functional.adaptive_avg_pool1d(combined_embedding.unsqueeze(0), 512).squeeze(0)
|
|
|
53 |
|
54 |
print(f"Generated product_embedding shape: {product_embedding.shape}") # Debugging
|
|
|
55 |
return product_embedding.detach().numpy()
|
56 |
|
57 |
# 사용자 타워: 데이터 임베딩
|
|
|
74 |
scaled_height = (height - min_height) * 99 // (max_height - min_height)
|
75 |
scaled_weight = (weight - min_weight) * 99 // (max_weight - min_weight)
|
76 |
|
77 |
+
age_embedding = embedding_layer(torch.tensor([user_data['age']])).view(1, -1)
|
78 |
+
gender_embedding = embedding_layer(torch.tensor([gender_id])).view(1, -1)
|
79 |
+
height_embedding = embedding_layer(torch.tensor([scaled_height])).view(1, -1)
|
80 |
+
weight_embedding = embedding_layer(torch.tensor([scaled_weight])).view(1, -1)
|
81 |
|
82 |
# 최종 임베딩 벡터 결합
|
83 |
+
combined_embedding = torch.cat((age_embedding, gender_embedding, height_embedding, weight_embedding), dim=1)
|
84 |
+
user_embedding = torch.nn.functional.adaptive_avg_pool1d(combined_embedding.unsqueeze(0), 512).squeeze(0)
|
85 |
+
|
86 |
+
print(f"Generated user_embedding shape: {reduced_embedding.shape}")
|
87 |
return user_embedding.detach().numpy()
|
88 |
|
89 |
# MongoDB Atlas에서 데이터 가져오기
|
|
|
128 |
print(f"User ID {user_id} embedding not found.")
|
129 |
return []
|
130 |
|
131 |
+
user_embedding = np.array(user_embedding_data["embedding"]).reshape(1, -1)
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
# 모든 상품 임베딩 가져오기
|
134 |
all_product_embeddings = list(product_embedding_collection.find())
|