Spaces:
Sleeping
Sleeping
File size: 4,380 Bytes
2bffc21 6c15b41 09ec91f 6c15b41 2bffc21 09ec91f 2bffc21 6c15b41 2bffc21 6c15b41 2bffc21 6c15b41 2bffc21 6c15b41 2bffc21 6c15b41 2bffc21 6c15b41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import os

import numpy as np
import torch
import torch.nn.functional as F
from pymongo import MongoClient
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer
# MongoDB Atlas connection settings.
# NOTE(review): credentials must not live in source control. Set the
# MONGODB_URI environment variable instead; the hard-coded fallback below is
# kept only for backward compatibility and should be removed and the password
# rotated. The fallback also looks mangled ("[email protected]") in this copy —
# confirm it is still a valid URI.
MONGODB_URI = os.environ.get(
    "MONGODB_URI",
    "mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority",
)
client = MongoClient(MONGODB_URI)
db = client["two_tower_model"]
train_dataset = db["train_dataset"]
# Load the pretrained tokenizer and BERT encoder, both from the same
# checkpoint ("klue/bert-base"); the tokenizer is loaded first.
tokenizer, bert_model = (
    BertTokenizer.from_pretrained("klue/bert-base"),
    BertModel.from_pretrained("klue/bert-base"),
)
# Product embedding function.
def embed_product_data(product, text_tokenizer=None, encoder=None):
    """Embed a product's name and description into a single 1-D vector.

    The ``product_name`` and ``product_description`` fields are joined with a
    single space, tokenized (truncated to 128 tokens), run through the
    encoder, and the encoder's ``last_hidden_state`` is mean-pooled over the
    token axis.

    Args:
        product: Mapping with optional ``product_name`` and
            ``product_description`` entries. Missing or ``None`` values are
            treated as empty strings.
        text_tokenizer: Tokenizer to use; defaults to the module-level
            ``tokenizer``.
        encoder: Model to use; defaults to the module-level ``bert_model``.

    Returns:
        A flattened numpy array of the pooled embedding.
    """
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    if encoder is None:
        encoder = bert_model

    # `or ""` also covers fields stored as None, which would otherwise make
    # the string concatenation raise TypeError.
    name = product.get("product_name") or ""
    description = product.get("product_description") or ""
    text = name + " " + description

    inputs = text_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    # Inference only: no_grad avoids building an autograd graph through the
    # whole encoder just to detach it afterwards.
    with torch.no_grad():
        outputs = encoder(**inputs)
    # Mean pooling over the token axis (dim=1), then flatten to 1-D.
    return outputs.last_hidden_state.mean(dim=1).detach().cpu().numpy().flatten()
# PyTorch Dataset definition.
class TripletDataset(Dataset):
    """Map-style dataset over precomputed (anchor, positive, negative) embeddings.

    Args:
        dataset: Indexable, sized sequence whose items are mappings with the
            keys ``anchor_embedding``, ``positive_embedding`` and
            ``negative_embedding`` (numeric array-likes).
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return item ``idx`` as an (anchor, positive, negative) tuple of float32 tensors."""
        record = self.dataset[idx]
        fields = ("anchor_embedding", "positive_embedding", "negative_embedding")
        return tuple(torch.tensor(record[field], dtype=torch.float32) for field in fields)
# Load the triplet dataset from MongoDB and convert it to embeddings.
def prepare_training_data(collection=None):
    """Load triplet documents and embed their products.

    Each document is expected to contain ``anchor``, ``positive`` and
    ``negative`` sub-documents, each holding a ``product`` mapping accepted by
    ``embed_product_data``. Documents that fail to embed are skipped (best
    effort) and reported by ``_id``.

    Args:
        collection: Object with a ``find()`` method returning the triplet
            documents; defaults to the module-level ``train_dataset``
            MongoDB collection.

    Returns:
        A ``TripletDataset`` over the successfully embedded triplets.

    Raises:
        ValueError: If the collection is empty, or if no document could be
            embedded. Returning an empty dataset instead would crash training
            later with a ZeroDivisionError when averaging the epoch loss.
    """
    if collection is None:
        collection = train_dataset
    documents = list(collection.find())  # fetch all triplet documents
    if not documents:
        raise ValueError("No training data found in MongoDB.")

    embedded_dataset = []
    for entry in documents:
        # Best effort: a malformed document or an embedding failure skips
        # this triplet instead of aborting the whole run.
        try:
            embedded_dataset.append({
                "anchor_embedding": embed_product_data(entry["anchor"]["product"]),
                "positive_embedding": embed_product_data(entry["positive"]["product"]),
                "negative_embedding": embed_product_data(entry["negative"]["product"]),
            })
        except Exception as e:
            print(f"Error embedding data (_id={entry.get('_id')}): {e}")

    if not embedded_dataset:
        raise ValueError(
            f"No valid training triplets: all {len(documents)} documents failed to embed."
        )
    skipped = len(documents) - len(embedded_dataset)
    if skipped:
        print(f"Skipped {skipped} of {len(documents)} documents due to embedding errors.")
    return TripletDataset(embedded_dataset)
# Train a model with triplet loss.
def train_triplet_model(product_model, train_loader, num_epochs=10, learning_rate=0.001, margin=1.0):
    """Train ``product_model`` in place with a triplet margin loss and Adam.

    The same model projects the anchor, positive and negative inputs of every
    batch. The loss is ``mean(max(d(a, p) - d(a, n) + margin, 0))`` with
    Euclidean distance ``d``. The mean loss per batch is printed each epoch.

    Args:
        product_model: ``torch.nn.Module`` to train.
        train_loader: Re-iterable of ``(anchor, positive, negative)`` batches
            that supports ``len()`` (e.g. a ``DataLoader``).
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        margin: Triplet margin.

    Returns:
        The same ``product_model`` instance, after training.

    Raises:
        ValueError: If training is requested but ``train_loader`` has no
            batches. Previously this surfaced as a ZeroDivisionError when
            averaging the epoch loss.
    """
    num_batches = len(train_loader)
    if num_batches == 0 and num_epochs > 0:
        raise ValueError("train_loader is empty; nothing to train on.")

    optimizer = Adam(product_model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        product_model.train()
        total_loss = 0.0
        for anchor, positive, negative in train_loader:
            optimizer.zero_grad()
            # Forward pass: shared weights for all three inputs.
            anchor_vec = product_model(anchor)
            positive_vec = product_model(positive)
            negative_vec = product_model(negative)
            # Hinge on pairwise L2 distances, mean-reduced over the batch.
            triplet_loss = F.triplet_margin_loss(anchor_vec, positive_vec, negative_vec, margin=margin)
            # Backpropagation and optimizer step.
            triplet_loss.backward()
            optimizer.step()
            total_loss += triplet_loss.item()
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / num_batches:.4f}")
    return product_model
# Model training pipeline.
def main():
    """Embed the MongoDB triplets, train a projection head, and save it.

    The head is a small MLP trained with triplet loss on embeddings that were
    precomputed with ``bert_model`` (only the head's parameters are
    optimized). Its ``state_dict`` is written to ``product_model.pth`` in the
    current working directory.
    """
    # Example projection head. The input width follows the encoder's hidden
    # size (768 for klue/bert-base) instead of a hard-coded 768, so swapping
    # the checkpoint cannot silently break the shapes.
    embedding_dim = bert_model.config.hidden_size
    product_model = torch.nn.Sequential(
        torch.nn.Linear(embedding_dim, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 128),
    )

    # Data preparation.
    triplet_dataset = prepare_training_data()
    train_loader = DataLoader(triplet_dataset, batch_size=16, shuffle=True)

    # Training.
    trained_model = train_triplet_model(product_model, train_loader)

    # Persist the trained weights.
    torch.save(trained_model.state_dict(), "product_model.pth")
    print("Model training completed and saved.")
# Run the pipeline only when executed as a script. Note that the module-level
# setup above (MongoDB client, BERT loading) still runs on plain import.
if __name__ == "__main__":
    main()
|