waseoke committed on
Commit
e1b98b6
·
verified ·
1 Parent(s): 59dd1ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -2
app.py CHANGED
@@ -1,3 +1,49 @@
1
- import streamlit as st
 
 
 
2
 
3
- st.title("Two-Tower Model Demo99999 RF-184 try5")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import torch
from pymongo import MongoClient
from torch.nn import Embedding
from transformers import BertTokenizer, BertModel
5
 
6
# MongoDB Atlas connection settings.
# The connection URI is read from the MONGODB_URI environment variable so that
# credentials do not have to live in the source file.
# NOTE(review): the fallback literal below still embeds credentials in source
# control; once MONGODB_URI is configured for the deployment, remove it and
# rotate the exposed password.
client = MongoClient(
    os.environ.get(
        "MONGODB_URI",
        "mongodb+srv://waseoke:[email protected]/",
    )
)
db = client["패션"]
product_collection = db["여성의류"]

# Load the Korean BERT tokenizer and model from Hugging Face (klue/bert-base).
tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
model = BertModel.from_pretrained("klue/bert-base")
14
+
15
+ # 상품 타워: 데이터 임베딩
16
# Lazily created ID-embedding layers, keyed by feature name, so that every call
# to embed_product_data looks IDs up in the same weight tables.
_ID_EMBEDDING_LAYERS = {}


def _id_embedding_layer(name, num_embeddings, embedding_dim):
    """Return the shared embedding layer for *name*, creating it on first use."""
    layer = _ID_EMBEDDING_LAYERS.get(name)
    if layer is None:
        layer = Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        _ID_EMBEDDING_LAYERS[name] = layer
    return layer


def embed_product_data(product_data):
    """Embed one product document into a single feature vector (product tower).

    The title and description are joined with a space, encoded with the
    module-level BERT ``tokenizer``/``model`` (truncated to 128 tokens), and
    mean-pooled over the token axis. The result is concatenated with a 16-dim
    category embedding and an 8-dim color embedding.

    Args:
        product_data: Mapping that may contain "title", "description",
            "category_id" and "color_id". Missing or ``None`` text fields are
            treated as empty strings; missing or ``None`` IDs default to 0.
            IDs index embedding tables of size 50 (category) and 20 (color),
            so they must fall inside those ranges.

    Returns:
        A NumPy array with one row: the pooled text embedding followed by the
        category and color embeddings.

    NOTE(review): the ID-embedding layers are randomly initialised and are not
    trained anywhere in this file. They are now shared across calls, so results
    are consistent within a process, but trained weights must be loaded before
    the category/color part of the vector carries meaning.
    """
    # `or ""` also covers fields that are present but stored as None.
    title = product_data.get("title") or ""
    description = product_data.get("description") or ""
    text = f"{title} {description}"

    # MongoDB may return numeric fields as floats; the lookup needs integers.
    category_id = int(product_data.get("category_id") or 0)
    color_id = int(product_data.get("color_id") or 0)

    category_embedding_layer = _id_embedding_layer("category", 50, 16)
    color_embedding_layer = _id_embedding_layer("color", 20, 8)

    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    )

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
        # Mean pooling over the token axis turns the sequence into one vector.
        text_embedding = outputs.last_hidden_state.mean(dim=1)

        category_embedding = category_embedding_layer(
            torch.tensor([category_id], dtype=torch.long)
        )
        color_embedding = color_embedding_layer(
            torch.tensor([color_id], dtype=torch.long)
        )

        # Concatenate along the feature axis into the final product vector.
        product_embedding = torch.cat(
            (text_embedding, category_embedding, color_embedding), dim=1
        )

    return product_embedding.numpy()
40
+
41
# Fetch a single product document from MongoDB by its product ID.
product_data = product_collection.find_one({"product_id": 7})

# Report a failed lookup; otherwise embed the document and print the vector.
if not product_data:
    print("Product not found.")
else:
    product_embedding = embed_product_data(product_data)
    print("Product Embedding:", product_embedding)