waseoke commited on
Commit
e6dc15e
·
verified ·
1 Parent(s): bb06e40

Update embed_data.py

Browse files
Files changed (1) hide show
  1. embed_data.py +47 -24
embed_data.py CHANGED
@@ -2,14 +2,14 @@ from pymongo import MongoClient
2
  from transformers import BertTokenizer, BertModel
3
  import torch
4
  from torch.nn import Embedding
5
- import numpy as np
6
- from sklearn.metrics.pairwise import cosine_similarity
7
 
8
  # MongoDB Atlas 연결 설정
9
- client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true")
 
 
10
  db = client["two_tower_model"]
11
  product_collection = db["product_tower"]
12
- user_collection = db['user_tower']
13
  product_embedding_collection = db["product_embeddings"] # 상품 임베딩을 저장할 컬렉션
14
  user_embedding_collection = db["user_embeddings"] # 사용자 임베딩을 저장할 컬렉션
15
 
@@ -23,6 +23,7 @@ max_height = 250
23
  min_weight = 30
24
  max_weight = 200
25
 
 
26
  # 상품 타워: 데이터 임베딩
27
  def embed_product_data(product_data):
28
  # 상품명과 상세 정보 임베딩 (BERT)
@@ -45,47 +46,63 @@ def embed_product_data(product_data):
45
 
46
  # 모든 임베딩 벡터 차원 맞추기
47
  category_embedding = category_embedding.view(1, -1) # 2D로 변환
48
- color_embedding = color_embedding.view(1, -1) # 2D로 변환
49
 
50
  # 최종 임베딩 벡터 결합
51
- combined_embedding = torch.cat((text_embedding, category_embedding, color_embedding), dim=1)
52
- product_embedding = torch.nn.functional.adaptive_avg_pool1d(combined_embedding.unsqueeze(0), 512).squeeze(0)
 
 
 
 
53
 
54
  return product_embedding.detach().numpy()
55
 
 
56
  # 사용자 타워: 데이터 임베딩
57
  def embed_user_data(user_data):
58
  # 나이, 성별, 키, 몸무게 임베딩 (임베딩 레이어)
59
- embedding_layer = Embedding(num_embeddings=100, embedding_dim=128) # 임의로 설정된 예시 값
 
 
60
 
61
  # 예를 들어 성별을 'M'은 0, 'F'는 1로 인코딩
62
- gender_id = 0 if user_data['gender'] == 'M' else 1
63
 
64
  # 스케일링 적용
65
- height = user_data['height']
66
- weight = user_data['weight']
67
 
68
  if not (min_height <= height <= max_height):
69
- raise ValueError(f"Invalid height value: {height}. Expected range: {min_height}-{max_height}")
 
 
70
  if not (min_weight <= weight <= max_weight):
71
- raise ValueError(f"Invalid weight value: {weight}. Expected range: {min_weight}-{max_weight}")
 
 
72
 
73
  scaled_height = (height - min_height) * 99 // (max_height - min_height)
74
  scaled_weight = (weight - min_weight) * 99 // (max_weight - min_weight)
75
-
76
- age_embedding = embedding_layer(torch.tensor([user_data['age']])).view(1, -1)
77
  gender_embedding = embedding_layer(torch.tensor([gender_id])).view(1, -1)
78
  height_embedding = embedding_layer(torch.tensor([scaled_height])).view(1, -1)
79
  weight_embedding = embedding_layer(torch.tensor([scaled_weight])).view(1, -1)
80
 
81
  # 최종 임베딩 벡터 결합
82
- combined_embedding = torch.cat((age_embedding, gender_embedding, height_embedding, weight_embedding), dim=1)
83
- user_embedding = torch.nn.functional.adaptive_avg_pool1d(combined_embedding.unsqueeze(0), 512).squeeze(0)
 
 
 
 
84
 
85
  return user_embedding.detach().numpy()
86
 
 
87
  # MongoDB Atlas에서 데이터 가져오기
88
- all_products = product_collection.find() # 모든 상품 데이터 가져오기
89
  all_users = user_collection.find() # 모든 사용자 데이터 가져오기
90
 
91
  # 상품 임베딩 수행
@@ -96,10 +113,14 @@ for product_data in all_products:
96
  # MongoDB Atlas의 product_embeddings 컬렉션에 임베딩 저장
97
  product_embedding_collection.update_one(
98
  {"product_id": product_data["product_id"]}, # product_id 기준으로 찾기
99
- {"$set": {"embedding": product_embedding.tolist()}}, # 벡터를 리스트 형태로 저장
100
- upsert=True # 기존 항목이 없으면 새로 삽입
 
 
 
 
 
101
  )
102
- print(f"Embedding saved to MongoDB Atlas for Product ID {product_data['product_id']}.")
103
 
104
  # 사용자 임베딩 수행
105
  for user_data in all_users:
@@ -110,9 +131,11 @@ for user_data in all_users:
110
  # MongoDB Atlas의 user_embeddings 컬렉션에 임베딩 저장
111
  user_embedding_collection.update_one(
112
  {"user_id": user_data["user_id"]}, # user_id 기준으로 찾기
113
- {"$set": {"embedding": user_embedding.tolist()}}, # 벡터를 리스트 형태로 저장
114
- upsert=True # 기존 항목이 없으면 새로 삽입
 
 
115
  )
116
  print(f"Embedding saved to MongoDB Atlas for user_id {user_data['user_id']}.")
117
  except ValueError as e:
118
- print(f"Skipping user_id {user_data['user_id']} due to error: {e}")
 
2
  from transformers import BertTokenizer, BertModel
3
  import torch
4
  from torch.nn import Embedding
 
 
5
 
6
  # MongoDB Atlas 연결 설정
7
+ client = MongoClient(
8
+ "mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true"
9
+ )
10
  db = client["two_tower_model"]
11
  product_collection = db["product_tower"]
12
+ user_collection = db["user_tower"]
13
  product_embedding_collection = db["product_embeddings"] # 상품 임베딩을 저장할 컬렉션
14
  user_embedding_collection = db["user_embeddings"] # 사용자 임베딩을 저장할 컬렉션
15
 
 
23
  min_weight = 30
24
  max_weight = 200
25
 
26
+
27
  # 상품 타워: 데이터 임베딩
28
  def embed_product_data(product_data):
29
  # 상품명과 상세 정보 임베딩 (BERT)
 
46
 
47
  # 모든 임베딩 벡터 차원 맞추기
48
  category_embedding = category_embedding.view(1, -1) # 2D로 변환
49
+ color_embedding = color_embedding.view(1, -1) # 2D로 변환
50
 
51
  # 최종 임베딩 벡터 결합
52
+ combined_embedding = torch.cat(
53
+ (text_embedding, category_embedding, color_embedding), dim=1
54
+ )
55
+ product_embedding = torch.nn.functional.adaptive_avg_pool1d(
56
+ combined_embedding.unsqueeze(0), 512
57
+ ).squeeze(0)
58
 
59
  return product_embedding.detach().numpy()
60
 
61
+
62
  # 사용자 타워: 데이터 임베딩
63
  def embed_user_data(user_data):
64
  # 나이, 성별, 키, 몸무게 임베딩 (임베딩 레이어)
65
+ embedding_layer = Embedding(
66
+ num_embeddings=100, embedding_dim=128
67
+ ) # 임의로 설정된 예시 값
68
 
69
  # 예를 들어 성별을 'M'은 0, 'F'는 1로 인코딩
70
+ gender_id = 0 if user_data["gender"] == "M" else 1
71
 
72
  # 스케일링 적용
73
+ height = user_data["height"]
74
+ weight = user_data["weight"]
75
 
76
  if not (min_height <= height <= max_height):
77
+ raise ValueError(
78
+ f"Invalid height value: {height}. Expected range: {min_height}-{max_height}"
79
+ )
80
  if not (min_weight <= weight <= max_weight):
81
+ raise ValueError(
82
+ f"Invalid weight value: {weight}. Expected range: {min_weight}-{max_weight}"
83
+ )
84
 
85
  scaled_height = (height - min_height) * 99 // (max_height - min_height)
86
  scaled_weight = (weight - min_weight) * 99 // (max_weight - min_weight)
87
+
88
+ age_embedding = embedding_layer(torch.tensor([user_data["age"]])).view(1, -1)
89
  gender_embedding = embedding_layer(torch.tensor([gender_id])).view(1, -1)
90
  height_embedding = embedding_layer(torch.tensor([scaled_height])).view(1, -1)
91
  weight_embedding = embedding_layer(torch.tensor([scaled_weight])).view(1, -1)
92
 
93
  # 최종 임베딩 벡터 결합
94
+ combined_embedding = torch.cat(
95
+ (age_embedding, gender_embedding, height_embedding, weight_embedding), dim=1
96
+ )
97
+ user_embedding = torch.nn.functional.adaptive_avg_pool1d(
98
+ combined_embedding.unsqueeze(0), 512
99
+ ).squeeze(0)
100
 
101
  return user_embedding.detach().numpy()
102
 
103
+
104
  # MongoDB Atlas에서 데이터 가져오기
105
+ all_products = product_collection.find() # 모든 상품 데이터 가져오기
106
  all_users = user_collection.find() # 모든 사용자 데이터 가져오기
107
 
108
  # 상품 임베딩 수행
 
113
  # MongoDB Atlas의 product_embeddings 컬렉션에 임베딩 저장
114
  product_embedding_collection.update_one(
115
  {"product_id": product_data["product_id"]}, # product_id 기준으로 찾기
116
+ {
117
+ "$set": {"embedding": product_embedding.tolist()}
118
+ }, # 벡터를 리스트 형태로 저장
119
+ upsert=True, # 기존 항목이 없으면 새로 삽입
120
+ )
121
+ print(
122
+ f"Embedding saved to MongoDB Atlas for Product ID {product_data['product_id']}."
123
  )
 
124
 
125
  # 사용자 임베딩 수행
126
  for user_data in all_users:
 
131
  # MongoDB Atlas의 user_embeddings 컬렉션에 임베딩 저장
132
  user_embedding_collection.update_one(
133
  {"user_id": user_data["user_id"]}, # user_id 기준으로 찾기
134
+ {
135
+ "$set": {"embedding": user_embedding.tolist()}
136
+ }, # 벡터를 리스트 형태로 저장
137
+ upsert=True, # 기존 항목이 없으면 새로 삽입
138
  )
139
  print(f"Embedding saved to MongoDB Atlas for user_id {user_data['user_id']}.")
140
  except ValueError as e:
141
+ print(f"Skipping user_id {user_data['user_id']} due to error: {e}")