waseoke commited on
Commit
2dc058f
·
verified ·
1 Parent(s): 73952f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -15
app.py CHANGED
@@ -48,12 +48,10 @@ def embed_product_data(product_data):
48
  color_embedding = color_embedding.view(1, -1) # 2D로 변환
49
 
50
  # 최종 임베딩 벡터 결합
51
- product_embedding = torch.cat(
52
- (text_embedding, category_embedding, color_embedding), dim=1
53
- )
54
 
55
  print(f"Generated product_embedding shape: {product_embedding.shape}") # Debugging
56
-
57
  return product_embedding.detach().numpy()
58
 
59
  # 사용자 타워: 데이터 임베딩
@@ -76,13 +74,16 @@ def embed_user_data(user_data):
76
  scaled_height = (height - min_height) * 99 // (max_height - min_height)
77
  scaled_weight = (weight - min_weight) * 99 // (max_weight - min_weight)
78
 
79
- age_embedding = embedding_layer(torch.tensor([user_data['age']]))
80
- gender_embedding = embedding_layer(torch.tensor([gender_id]))
81
- height_embedding = embedding_layer(torch.tensor([scaled_height]))
82
- weight_embedding = embedding_layer(torch.tensor([scaled_weight]))
83
 
84
  # 최종 임베딩 벡터 결합
85
- user_embedding = torch.cat((age_embedding, gender_embedding, height_embedding, weight_embedding), dim=1)
 
 
 
86
  return user_embedding.detach().numpy()
87
 
88
  # MongoDB Atlas에서 데이터 가져오기
@@ -127,12 +128,7 @@ def recommend_products_for_user(user_id, top_n=1):
127
  print(f"User ID {user_id} embedding not found.")
128
  return []
129
 
130
- # user_embedding = np.array(user_embedding_data["embedding"]).reshape(1, -1)
131
- user_embedding = np.array(user_embedding_data["embedding"])
132
-
133
- # 차원 확인 및 조정
134
- if user_embedding.ndim == 1: # 1D 배열인 경우 2D로 변환
135
- user_embedding = user_embedding.reshape(1, -1)
136
 
137
  # 모든 상품 임베딩 가져오기
138
  all_product_embeddings = list(product_embedding_collection.find())
 
48
  color_embedding = color_embedding.view(1, -1) # 2D로 변환
49
 
50
  # 최종 임베딩 벡터 결합
51
+ combined_embedding = torch.cat((text_embedding, category_embedding, color_embedding), dim=1)
52
+ product_embedding = torch.nn.functional.adaptive_avg_pool1d(combined_embedding.unsqueeze(0), 512).squeeze(0)
 
53
 
54
  print(f"Generated product_embedding shape: {product_embedding.shape}") # Debugging
 
55
  return product_embedding.detach().numpy()
56
 
57
  # 사용자 타워: 데이터 임베딩
 
74
  scaled_height = (height - min_height) * 99 // (max_height - min_height)
75
  scaled_weight = (weight - min_weight) * 99 // (max_weight - min_weight)
76
 
77
+ age_embedding = embedding_layer(torch.tensor([user_data['age']])).view(1, -1)
78
+ gender_embedding = embedding_layer(torch.tensor([gender_id])).view(1, -1)
79
+ height_embedding = embedding_layer(torch.tensor([scaled_height])).view(1, -1)
80
+ weight_embedding = embedding_layer(torch.tensor([scaled_weight])).view(1, -1)
81
 
82
  # 최종 임베딩 벡터 결합
83
+ combined_embedding = torch.cat((age_embedding, gender_embedding, height_embedding, weight_embedding), dim=1)
84
+ user_embedding = torch.nn.functional.adaptive_avg_pool1d(combined_embedding.unsqueeze(0), 512).squeeze(0)
85
+
86
+ print(f"Generated user_embedding shape: {reduced_embedding.shape}")
87
  return user_embedding.detach().numpy()
88
 
89
  # MongoDB Atlas에서 데이터 가져오기
 
128
  print(f"User ID {user_id} embedding not found.")
129
  return []
130
 
131
+ user_embedding = np.array(user_embedding_data["embedding"]).reshape(1, -1)
 
 
 
 
 
132
 
133
  # 모든 상품 임베딩 가져오기
134
  all_product_embeddings = list(product_embedding_collection.find())