Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,49 @@
|
|
1 |
-
|
|
|
|
|
|
|
2 |
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pymongo import MongoClient
|
2 |
+
from transformers import BertTokenizer, BertModel
|
3 |
+
import torch
|
4 |
+
from torch.nn import Embedding
|
5 |
|
6 |
+
# MongoDB Atlas 연결 설정
|
7 |
+
client = MongoClient("mongodb+srv://waseoke:[email protected]/")
|
8 |
+
db = client["패션"]
|
9 |
+
product_collection = db["여성의류"]
|
10 |
+
|
11 |
+
# Hugging Face의 한국어 BERT 모델 및 토크나이저 로드 (예: klue/bert-base)
|
12 |
+
tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
|
13 |
+
model = BertModel.from_pretrained("klue/bert-base")
|
14 |
+
|
15 |
+
# 상품 타워: 데이터 임베딩
|
16 |
+
def embed_product_data(product_data):
|
17 |
+
# 상품명과 상세 정보 임베딩 (BERT)
|
18 |
+
text = product_data.get("title", "") + " " + product_data.get("description", "")
|
19 |
+
inputs = tokenizer(
|
20 |
+
text, return_tensors="pt", truncation=True, padding=True, max_length=128
|
21 |
+
)
|
22 |
+
outputs = model(**inputs)
|
23 |
+
text_embedding = outputs.last_hidden_state.mean(dim=1) # 평균 풀링으로 벡터화
|
24 |
+
|
25 |
+
# 카테고리 및 색상 정보 임베딩 (임베딩 레이어)
|
26 |
+
category_embedding_layer = Embedding(num_embeddings=50, embedding_dim=16)
|
27 |
+
color_embedding_layer = Embedding(num_embeddings=20, embedding_dim=8)
|
28 |
+
|
29 |
+
category_id = product_data.get("category_id", 0) # 카테고리 ID, 기본값 0
|
30 |
+
color_id = product_data.get("color_id", 0) # 색상 ID, 기본값 0
|
31 |
+
|
32 |
+
category_embedding = category_embedding_layer(torch.tensor([category_id]))
|
33 |
+
color_embedding = color_embedding_layer(torch.tensor([color_id]))
|
34 |
+
|
35 |
+
# 최종 임베딩 벡터 결합
|
36 |
+
product_embedding = torch.cat(
|
37 |
+
(text_embedding, category_embedding, color_embedding), dim=1
|
38 |
+
)
|
39 |
+
return product_embedding.detach().numpy()
|
40 |
+
|
41 |
+
# MongoDB에서 데이터 가져오기
|
42 |
+
product_data = product_collection.find_one({"product_id": 7}) # 특정 상품 ID
|
43 |
+
|
44 |
+
# 임베딩 수행
|
45 |
+
if product_data:
|
46 |
+
product_embedding = embed_product_data(product_data)
|
47 |
+
print("Product Embedding:", product_embedding)
|
48 |
+
else:
|
49 |
+
print("Product not found.")
|