waseoke commited on
Commit
eb7ba37
·
verified ·
1 Parent(s): cdee98f

Create embed_data.py

Browse files
Files changed (1) hide show
  1. embed_data.py +118 -0
embed_data.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pymongo import MongoClient
2
+ from transformers import BertTokenizer, BertModel
3
+ import torch
4
+ from torch.nn import Embedding
5
+ import numpy as np
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+
8
+ # MongoDB Atlas 연결 설정
9
+ client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true")
10
+ db = client["two_tower_model"]
11
+ product_collection = db["product_tower"]
12
+ user_collection = db['user_tower']
13
+ product_embedding_collection = db["product_embeddings"] # 상품 임베딩을 저장할 컬렉션
14
+ user_embedding_collection = db["user_embeddings"] # 사용자 임베딩을 저장할 컬렉션
15
+
16
+ # Hugging Face의 한국어 BERT 모델 및 토크나이저 로드 (예: klue/bert-base)
17
+ tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
18
+ model = BertModel.from_pretrained("klue/bert-base")
19
+
20
+ # Height와 Weight 스케일링에 필요한 값 설정
21
+ min_height = 50
22
+ max_height = 250
23
+ min_weight = 30
24
+ max_weight = 200
25
+
26
+ # 상품 타워: 데이터 임베딩
27
+ def embed_product_data(product_data):
28
+ # 상품명과 상세 정보 임베딩 (BERT)
29
+ text = product_data.get("title", "") + " " + product_data.get("description", "")
30
+ inputs = tokenizer(
31
+ text, return_tensors="pt", truncation=True, padding=True, max_length=128
32
+ )
33
+ outputs = model(**inputs)
34
+ text_embedding = outputs.last_hidden_state.mean(dim=1) # 평균 풀링으로 벡터화
35
+
36
+ # 카테고리 및 색상 정보 임베딩 (임베딩 레이어)
37
+ category_embedding_layer = Embedding(num_embeddings=50, embedding_dim=16)
38
+ color_embedding_layer = Embedding(num_embeddings=20, embedding_dim=8)
39
+
40
+ category_id = product_data.get("category_id", 0) # 카테고리 ID, 기본값 0
41
+ color_id = product_data.get("color_id", 0) # 색상 ID, 기본값 0
42
+
43
+ category_embedding = category_embedding_layer(torch.tensor([category_id]))
44
+ color_embedding = color_embedding_layer(torch.tensor([color_id]))
45
+
46
+ # 모든 임베딩 벡터 차원 맞추기
47
+ category_embedding = category_embedding.view(1, -1) # 2D로 변환
48
+ color_embedding = color_embedding.view(1, -1) # 2D로 변환
49
+
50
+ # 최종 임베딩 벡터 결합
51
+ combined_embedding = torch.cat((text_embedding, category_embedding, color_embedding), dim=1)
52
+ product_embedding = torch.nn.functional.adaptive_avg_pool1d(combined_embedding.unsqueeze(0), 512).squeeze(0)
53
+
54
+ return product_embedding.detach().numpy()
55
+
56
+ # 사용자 타워: 데이터 임베딩
57
+ def embed_user_data(user_data):
58
+ # 나이, 성별, 키, 몸무게 임베딩 (임베딩 레이어)
59
+ embedding_layer = Embedding(num_embeddings=100, embedding_dim=128) # 임의로 설정된 예시 값
60
+
61
+ # 예를 들어 성별을 'M'은 0, 'F'는 1로 인코딩
62
+ gender_id = 0 if user_data['gender'] == 'M' else 1
63
+
64
+ # 스케일링 적용
65
+ height = user_data['height']
66
+ weight = user_data['weight']
67
+
68
+ if not (min_height <= height <= max_height):
69
+ raise ValueError(f"Invalid height value: {height}. Expected range: {min_height}-{max_height}")
70
+ if not (min_weight <= weight <= max_weight):
71
+ raise ValueError(f"Invalid weight value: {weight}. Expected range: {min_weight}-{max_weight}")
72
+
73
+ scaled_height = (height - min_height) * 99 // (max_height - min_height)
74
+ scaled_weight = (weight - min_weight) * 99 // (max_weight - min_weight)
75
+
76
+ age_embedding = embedding_layer(torch.tensor([user_data['age']])).view(1, -1)
77
+ gender_embedding = embedding_layer(torch.tensor([gender_id])).view(1, -1)
78
+ height_embedding = embedding_layer(torch.tensor([scaled_height])).view(1, -1)
79
+ weight_embedding = embedding_layer(torch.tensor([scaled_weight])).view(1, -1)
80
+
81
+ # 최종 임베딩 벡터 결합
82
+ combined_embedding = torch.cat((age_embedding, gender_embedding, height_embedding, weight_embedding), dim=1)
83
+ user_embedding = torch.nn.functional.adaptive_avg_pool1d(combined_embedding.unsqueeze(0), 512).squeeze(0)
84
+
85
+ return user_embedding.detach().numpy()
86
+
87
+ # MongoDB Atlas에서 데이터 가져오기
88
+ all_products = product_collection.find() # 모든 상품 데이터 가져오기
89
+ all_users = user_collection.find() # 모든 사용자 데이터 가져오기
90
+
91
+ # 상품 임베딩 수행
92
+ for product_data in all_products:
93
+ product_embedding = embed_product_data(product_data)
94
+ print(f"Product ID {product_data['product_id']} Embedding: {product_embedding}")
95
+
96
+ # MongoDB Atlas의 product_embeddings 컬렉션에 임베딩 저장
97
+ product_embedding_collection.update_one(
98
+ {"product_id": product_data["product_id"]}, # product_id 기준으로 찾기
99
+ {"$set": {"embedding": product_embedding.tolist()}}, # 벡터를 리스트 형태로 저장
100
+ upsert=True # 기존 항목이 없으면 새로 삽입
101
+ )
102
+ print(f"Embedding saved to MongoDB Atlas for Product ID {product_data['product_id']}.")
103
+
104
+ # 사용자 임베딩 수행
105
+ for user_data in all_users:
106
+ try:
107
+ user_embedding = embed_user_data(user_data)
108
+ print(f"User ID {user_data['user_id']} Embedding:", user_embedding)
109
+
110
+ # MongoDB Atlas의 user_embeddings 컬렉션에 임베딩 저장
111
+ user_embedding_collection.update_one(
112
+ {"user_id": user_data["user_id"]}, # user_id 기준으로 찾기
113
+ {"$set": {"embedding": user_embedding.tolist()}}, # 벡터를 리스트 형태로 저장
114
+ upsert=True # 기존 항목이 없으면 새로 삽입
115
+ )
116
+ print(f"Embedding saved to MongoDB Atlas for user_id {user_data['user_id']}.")
117
+ except ValueError as e:
118
+ print(f"Skipping user_id {user_data['user_id']} due to error: {e}")