Spaces:
Sleeping
Sleeping
Create embed_data.py
Browse files- embed_data.py +118 -0
embed_data.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pymongo import MongoClient
|
2 |
+
from transformers import BertTokenizer, BertModel
|
3 |
+
import torch
|
4 |
+
from torch.nn import Embedding
|
5 |
+
import numpy as np
|
6 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
7 |
+
|
8 |
+
# MongoDB Atlas 연결 설정
|
9 |
+
client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true")
|
10 |
+
db = client["two_tower_model"]
|
11 |
+
product_collection = db["product_tower"]
|
12 |
+
user_collection = db['user_tower']
|
13 |
+
product_embedding_collection = db["product_embeddings"] # 상품 임베딩을 저장할 컬렉션
|
14 |
+
user_embedding_collection = db["user_embeddings"] # 사용자 임베딩을 저장할 컬렉션
|
15 |
+
|
16 |
+
# Hugging Face의 한국어 BERT 모델 및 토크나이저 로드 (예: klue/bert-base)
|
17 |
+
tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
|
18 |
+
model = BertModel.from_pretrained("klue/bert-base")
|
19 |
+
|
20 |
+
# Height와 Weight 스케일링에 필요한 값 설정
|
21 |
+
min_height = 50
|
22 |
+
max_height = 250
|
23 |
+
min_weight = 30
|
24 |
+
max_weight = 200
|
25 |
+
|
26 |
+
# 상품 타워: 데이터 임베딩
|
27 |
+
def embed_product_data(product_data):
|
28 |
+
# 상품명과 상세 정보 임베딩 (BERT)
|
29 |
+
text = product_data.get("title", "") + " " + product_data.get("description", "")
|
30 |
+
inputs = tokenizer(
|
31 |
+
text, return_tensors="pt", truncation=True, padding=True, max_length=128
|
32 |
+
)
|
33 |
+
outputs = model(**inputs)
|
34 |
+
text_embedding = outputs.last_hidden_state.mean(dim=1) # 평균 풀링으로 벡터화
|
35 |
+
|
36 |
+
# 카테고리 및 색상 정보 임베딩 (임베딩 레이어)
|
37 |
+
category_embedding_layer = Embedding(num_embeddings=50, embedding_dim=16)
|
38 |
+
color_embedding_layer = Embedding(num_embeddings=20, embedding_dim=8)
|
39 |
+
|
40 |
+
category_id = product_data.get("category_id", 0) # 카테고리 ID, 기본값 0
|
41 |
+
color_id = product_data.get("color_id", 0) # 색상 ID, 기본값 0
|
42 |
+
|
43 |
+
category_embedding = category_embedding_layer(torch.tensor([category_id]))
|
44 |
+
color_embedding = color_embedding_layer(torch.tensor([color_id]))
|
45 |
+
|
46 |
+
# 모든 임베딩 벡터 차원 맞추기
|
47 |
+
category_embedding = category_embedding.view(1, -1) # 2D로 변환
|
48 |
+
color_embedding = color_embedding.view(1, -1) # 2D로 변환
|
49 |
+
|
50 |
+
# 최종 임베딩 벡터 결합
|
51 |
+
combined_embedding = torch.cat((text_embedding, category_embedding, color_embedding), dim=1)
|
52 |
+
product_embedding = torch.nn.functional.adaptive_avg_pool1d(combined_embedding.unsqueeze(0), 512).squeeze(0)
|
53 |
+
|
54 |
+
return product_embedding.detach().numpy()
|
55 |
+
|
56 |
+
# 사용자 타워: 데이터 임베딩
|
57 |
+
def embed_user_data(user_data):
|
58 |
+
# 나이, 성별, 키, 몸무게 임베딩 (임베딩 레이어)
|
59 |
+
embedding_layer = Embedding(num_embeddings=100, embedding_dim=128) # 임의로 설정된 예시 값
|
60 |
+
|
61 |
+
# 예를 들어 성별을 'M'은 0, 'F'는 1로 인코딩
|
62 |
+
gender_id = 0 if user_data['gender'] == 'M' else 1
|
63 |
+
|
64 |
+
# 스케일링 적용
|
65 |
+
height = user_data['height']
|
66 |
+
weight = user_data['weight']
|
67 |
+
|
68 |
+
if not (min_height <= height <= max_height):
|
69 |
+
raise ValueError(f"Invalid height value: {height}. Expected range: {min_height}-{max_height}")
|
70 |
+
if not (min_weight <= weight <= max_weight):
|
71 |
+
raise ValueError(f"Invalid weight value: {weight}. Expected range: {min_weight}-{max_weight}")
|
72 |
+
|
73 |
+
scaled_height = (height - min_height) * 99 // (max_height - min_height)
|
74 |
+
scaled_weight = (weight - min_weight) * 99 // (max_weight - min_weight)
|
75 |
+
|
76 |
+
age_embedding = embedding_layer(torch.tensor([user_data['age']])).view(1, -1)
|
77 |
+
gender_embedding = embedding_layer(torch.tensor([gender_id])).view(1, -1)
|
78 |
+
height_embedding = embedding_layer(torch.tensor([scaled_height])).view(1, -1)
|
79 |
+
weight_embedding = embedding_layer(torch.tensor([scaled_weight])).view(1, -1)
|
80 |
+
|
81 |
+
# 최종 임베딩 벡터 결합
|
82 |
+
combined_embedding = torch.cat((age_embedding, gender_embedding, height_embedding, weight_embedding), dim=1)
|
83 |
+
user_embedding = torch.nn.functional.adaptive_avg_pool1d(combined_embedding.unsqueeze(0), 512).squeeze(0)
|
84 |
+
|
85 |
+
return user_embedding.detach().numpy()
|
86 |
+
|
87 |
+
# MongoDB Atlas에서 데이터 가져오기
|
88 |
+
all_products = product_collection.find() # 모든 상품 데이터 가져오기
|
89 |
+
all_users = user_collection.find() # 모든 사용자 데이터 가져오기
|
90 |
+
|
91 |
+
# 상품 임베딩 수행
|
92 |
+
for product_data in all_products:
|
93 |
+
product_embedding = embed_product_data(product_data)
|
94 |
+
print(f"Product ID {product_data['product_id']} Embedding: {product_embedding}")
|
95 |
+
|
96 |
+
# MongoDB Atlas의 product_embeddings 컬렉션에 임베딩 저장
|
97 |
+
product_embedding_collection.update_one(
|
98 |
+
{"product_id": product_data["product_id"]}, # product_id 기준으로 찾기
|
99 |
+
{"$set": {"embedding": product_embedding.tolist()}}, # 벡터를 리스트 형태로 저장
|
100 |
+
upsert=True # 기존 항목이 없으면 새로 삽입
|
101 |
+
)
|
102 |
+
print(f"Embedding saved to MongoDB Atlas for Product ID {product_data['product_id']}.")
|
103 |
+
|
104 |
+
# 사용자 임베딩 수행
|
105 |
+
for user_data in all_users:
|
106 |
+
try:
|
107 |
+
user_embedding = embed_user_data(user_data)
|
108 |
+
print(f"User ID {user_data['user_id']} Embedding:", user_embedding)
|
109 |
+
|
110 |
+
# MongoDB Atlas의 user_embeddings 컬렉션에 임베딩 저장
|
111 |
+
user_embedding_collection.update_one(
|
112 |
+
{"user_id": user_data["user_id"]}, # user_id 기준으로 찾기
|
113 |
+
{"$set": {"embedding": user_embedding.tolist()}}, # 벡터를 리스트 형태로 저장
|
114 |
+
upsert=True # 기존 항목이 없으면 새로 삽입
|
115 |
+
)
|
116 |
+
print(f"Embedding saved to MongoDB Atlas for user_id {user_data['user_id']}.")
|
117 |
+
except ValueError as e:
|
118 |
+
print(f"Skipping user_id {user_data['user_id']} due to error: {e}")
|