Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,9 +5,11 @@ from torch.nn import Embedding
|
|
5 |
|
6 |
# MongoDB Atlas 연결 설정
|
7 |
client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true")
|
8 |
-
db = client["
|
9 |
-
product_collection = db["
|
10 |
-
|
|
|
|
|
11 |
|
12 |
# Hugging Face의 한국어 BERT 모델 및 토크나이저 로드 (예: klue/bert-base)
|
13 |
tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
|
@@ -39,10 +41,27 @@ def embed_product_data(product_data):
|
|
39 |
)
|
40 |
return product_embedding.detach().numpy()
|
41 |
|
42 |
-
#
|
43 |
-
|
|
|
|
|
44 |
|
45 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
if product_data:
|
47 |
product_embedding = embed_product_data(product_data)
|
48 |
print("Product Embedding:", product_embedding)
|
@@ -55,4 +74,19 @@ if product_data:
|
|
55 |
)
|
56 |
print("Embedding saved to MongoDB Atlas based on product_id.")
|
57 |
else:
|
58 |
-
print("Product not found.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
# MongoDB Atlas 연결 설정
|
7 |
client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true")
|
8 |
+
db = client["two_tower_model"]
|
9 |
+
product_collection = db["product_tower"]
|
10 |
+
user_collection = db['user_tower']
|
11 |
+
product_embedding_collection = db["product_embeddings"] # 상품 임베딩을 저장할 컬렉션
|
12 |
+
user_embedding_collection = db["user_embeddings"] # 사용자 임베딩을 저장할 컬렉션
|
13 |
|
14 |
# Hugging Face의 한국어 BERT 모델 및 토크나이저 로드 (예: klue/bert-base)
|
15 |
tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
|
|
|
41 |
)
|
42 |
return product_embedding.detach().numpy()
|
43 |
|
44 |
+
# 사용자 타워: 데이터 임베딩
|
45 |
+
def embed_user_data(user_data):
|
46 |
+
# 나이, 성별, 키, 몸무게 임베딩 (임베딩 레이어)
|
47 |
+
embedding_layer = Embedding(num_embeddings=100, embedding_dim=32) # 임의로 설정된 예시 값
|
48 |
|
49 |
+
# 예를 들어 성별을 'M'은 0, 'F'는 1로 인코딩
|
50 |
+
gender_id = 0 if user_data['gender'] == 'M' else 1
|
51 |
+
age_embedding = embedding_layer(torch.tensor([user_data['age']]))
|
52 |
+
gender_embedding = embedding_layer(torch.tensor([gender_id]))
|
53 |
+
height_embedding = embedding_layer(torch.tensor([user_data['height']]))
|
54 |
+
weight_embedding = embedding_layer(torch.tensor([user_data['weight']]))
|
55 |
+
|
56 |
+
# 최종 임베딩 벡터 결합
|
57 |
+
user_embedding = torch.cat((age_embedding, gender_embedding, height_embedding, weight_embedding), dim=1)
|
58 |
+
return user_embedding.detach().numpy()
|
59 |
+
|
60 |
+
# MongoDB Atlas에서 데이터 가져오기
|
61 |
+
product_data = product_collection.find_one({"product_id": 1}) # 특정 상품 ID
|
62 |
+
user_data = user_collection.find_one({'user_id': 1}) # 특정 사용자 ID
|
63 |
+
|
64 |
+
# 상품 임베딩 수행
|
65 |
if product_data:
|
66 |
product_embedding = embed_product_data(product_data)
|
67 |
print("Product Embedding:", product_embedding)
|
|
|
74 |
)
|
75 |
print("Embedding saved to MongoDB Atlas based on product_id.")
|
76 |
else:
|
77 |
+
print("Product not found.")
|
78 |
+
|
79 |
+
# 사용자 임베딩 수행
|
80 |
+
if user_data:
|
81 |
+
user_embedding = embed_user_data(user_data)
|
82 |
+
print("User Embedding:", user_embedding)
|
83 |
+
|
84 |
+
# MongoDB Atlas의 user_embeddings 컬렉션에 임베딩 저장
|
85 |
+
embedding_collection.update_one(
|
86 |
+
{"user_id": user_data["user_id"]}, # user_id 기준으로 찾기
|
87 |
+
{"$set": {"embedding": user_embedding.tolist()}}, # 벡터를 리스트 형태로 저장
|
88 |
+
upsert=True # 기존 항목이 없으면 새로 삽입
|
89 |
+
)
|
90 |
+
print("Embedding saved to MongoDB Atlas based on user_id.")
|
91 |
+
else:
|
92 |
+
print("User not found.")
|