waseoke commited on
Commit
ab1f9e3
·
verified ·
1 Parent(s): d9888f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -7
app.py CHANGED
@@ -5,9 +5,11 @@ from torch.nn import Embedding
5
 
6
  # MongoDB Atlas 연결 설정
7
  client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true")
8
- db = client["패션"]
9
- product_collection = db["여성의류"]
10
- embedding_collection = db["product_embeddings"] # 임베딩을 저장할 컬렉션
 
 
11
 
12
  # Hugging Face의 한국어 BERT 모델 및 토크나이저 로드 (예: klue/bert-base)
13
  tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
@@ -39,10 +41,27 @@ def embed_product_data(product_data):
39
  )
40
  return product_embedding.detach().numpy()
41
 
42
- # MongoDB에서 데이터 가져오기
43
- product_data = product_collection.find_one({"product_id": "1"}) # 특정 상품 ID
 
 
44
 
45
- # 임베딩 수행
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  if product_data:
47
  product_embedding = embed_product_data(product_data)
48
  print("Product Embedding:", product_embedding)
@@ -55,4 +74,19 @@ if product_data:
55
  )
56
  print("Embedding saved to MongoDB Atlas based on product_id.")
57
  else:
58
- print("Product not found.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # MongoDB Atlas 연결 설정
7
  client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true")
8
+ db = client["two_tower_model"]
9
+ product_collection = db["product_tower"]
10
+ user_collection = db['user_tower']
11
+ product_embedding_collection = db["product_embeddings"] # 상품 임베딩을 저장할 컬렉션
12
+ user_embedding_collection = db["user_embeddings"] # 사용자 임베딩을 저장할 컬렉션
13
 
14
  # Hugging Face의 한국어 BERT 모델 및 토크나이저 로드 (예: klue/bert-base)
15
  tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
 
41
  )
42
  return product_embedding.detach().numpy()
43
 
44
+ # 사용자 타워: 데이터 임베딩
45
+ def embed_user_data(user_data):
46
+ # 나이, 성별, 키, 몸무게 임베딩 (임베딩 레이어)
47
+ embedding_layer = Embedding(num_embeddings=100, embedding_dim=32) # 임의로 설정된 예시 값
48
 
49
+ # 예를 들어 성별을 'M'은 0, 'F'는 1로 인코딩
50
+ gender_id = 0 if user_data['gender'] == 'M' else 1
51
+ age_embedding = embedding_layer(torch.tensor([user_data['age']]))
52
+ gender_embedding = embedding_layer(torch.tensor([gender_id]))
53
+ height_embedding = embedding_layer(torch.tensor([user_data['height']]))
54
+ weight_embedding = embedding_layer(torch.tensor([user_data['weight']]))
55
+
56
+ # 최종 임베딩 벡터 결합
57
+ user_embedding = torch.cat((age_embedding, gender_embedding, height_embedding, weight_embedding), dim=1)
58
+ return user_embedding.detach().numpy()
59
+
60
+ # MongoDB Atlas에서 데이터 가져오기
61
+ product_data = product_collection.find_one({"product_id": 1}) # 특정 상품 ID
62
+ user_data = user_collection.find_one({'user_id': 1}) # 특정 사용자 ID
63
+
64
+ # 상품 임베딩 수행
65
  if product_data:
66
  product_embedding = embed_product_data(product_data)
67
  print("Product Embedding:", product_embedding)
 
74
  )
75
  print("Embedding saved to MongoDB Atlas based on product_id.")
76
  else:
77
+ print("Product not found.")
78
+
79
+ # 사용자 임베딩 수행
80
+ if user_data:
81
+ user_embedding = embed_user_data(user_data)
82
+ print("User Embedding:", user_embedding)
83
+
84
+ # MongoDB Atlas의 user_embeddings 컬렉션에 임베딩 저장
85
+ embedding_collection.update_one(
86
+ {"user_id": user_data["user_id"]}, # user_id 기준으로 찾기
87
+ {"$set": {"embedding": user_embedding.tolist()}}, # 벡터를 리스트 형태로 저장
88
+ upsert=True # 기존 항목이 없으면 새로 삽입
89
+ )
90
+ print("Embedding saved to MongoDB Atlas based on user_id.")
91
+ else:
92
+ print("User not found.")