File size: 4,380 Bytes
2bffc21
 
 
6c15b41
 
 
 
09ec91f
6c15b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bffc21
09ec91f
2bffc21
 
6c15b41
2bffc21
6c15b41
 
2bffc21
6c15b41
 
 
 
2bffc21
6c15b41
 
 
 
2bffc21
6c15b41
 
 
 
 
 
 
2bffc21
 
6c15b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os

import numpy as np
import torch
import torch.nn.functional as F
from pymongo import MongoClient
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer

# MongoDB Atlas connection settings.
# NOTE(review): credentials must not live in source control. The URI is read
# from the MONGODB_URI environment variable; the hard-coded value below is kept
# only as a backward-compatible fallback and should be removed (and the
# password rotated) once every deployment sets the variable.
MONGODB_URI = os.environ.get(
    "MONGODB_URI",
    "mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority",
)
client = MongoClient(MONGODB_URI)
db = client["two_tower_model"]
train_dataset = db["train_dataset"]  # collection read by prepare_training_data()

# Load the BERT model and tokenizer (e.g. klue/bert-base).
tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
bert_model = BertModel.from_pretrained("klue/bert-base")

# Product embedding function.
def embed_product_data(product, max_length=128):
    """
    Embed a product document as one BERT sentence vector.

    The product's name and description are joined with a single space,
    tokenized with the module-level ``tokenizer`` (truncated to
    ``max_length`` tokens), passed through ``bert_model``, and mean-pooled
    over the token axis of ``last_hidden_state``.

    Args:
        product: Mapping with optional ``"product_name"`` and
            ``"product_description"`` fields. Missing or ``None`` values are
            treated as empty strings.
        max_length: Maximum number of tokens kept after truncation.

    Returns:
        A flattened 1-D ``numpy.ndarray`` holding the pooled embedding.
    """
    parts = (product.get("product_name"), product.get("product_description"))
    # A field stored as None used to raise TypeError on string concatenation.
    text = " ".join("" if part is None else str(part) for part in parts)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)
    # The result is converted to NumPy and never back-propagated through,
    # so skip building the autograd graph for the BERT forward pass.
    with torch.no_grad():
        outputs = bert_model(**inputs)
    embedding = outputs.last_hidden_state.mean(dim=1).numpy().flatten()  # mean pooling
    return embedding

# PyTorch Dataset definition.
class TripletDataset(Dataset):
    """
    Map-style dataset over precomputed (anchor, positive, negative) embeddings.

    Each item of ``dataset`` must be a mapping that provides the keys
    ``anchor_embedding``, ``positive_embedding`` and ``negative_embedding``.
    Indexing returns the three values as ``float32`` tensors, in that order.
    """

    # Lookup order of the embeddings returned by __getitem__.
    _TRIPLET_KEYS = ("anchor_embedding", "positive_embedding", "negative_embedding")

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        return tuple(
            torch.tensor(record[key], dtype=torch.float32)
            for key in self._TRIPLET_KEYS
        )

# Load the dataset from MongoDB and convert it to embeddings.
def prepare_training_data():
    """
    Load training triplets from MongoDB and embed them with BERT.

    Every document in the ``train_dataset`` collection is expected to have
    ``anchor``, ``positive`` and ``negative`` sub-documents, each holding a
    ``product`` mapping that is passed to ``embed_product_data``. Documents
    that fail to embed are reported and skipped.

    Returns:
        TripletDataset of dicts with ``anchor_embedding``,
        ``positive_embedding`` and ``negative_embedding`` entries.

    Raises:
        ValueError: if the collection is empty, or if no document could be
            embedded (an empty dataset would leave nothing to train on).
    """
    documents = list(train_dataset.find())  # fetch every triplet document
    if not documents:
        raise ValueError("No training data found in MongoDB.")

    # Build anchor / positive / negative embeddings for each document.
    embedded_dataset = []
    skipped = 0
    for entry in documents:
        try:
            record = {
                "anchor_embedding": embed_product_data(entry["anchor"]["product"]),
                "positive_embedding": embed_product_data(entry["positive"]["product"]),
                "negative_embedding": embed_product_data(entry["negative"]["product"]),
            }
        except Exception as e:
            # Best-effort: one malformed document should not abort the run,
            # but identify it so it can be fixed at the source.
            skipped += 1
            print(f"Error embedding data (_id={entry.get('_id')}): {e}")
            continue
        embedded_dataset.append(record)

    if skipped:
        print(f"Skipped {skipped} of {len(documents)} training documents.")
    if not embedded_dataset:
        # Previously this returned an empty dataset and training later
        # failed with ZeroDivisionError; fail early with a clear message.
        raise ValueError("No training triplets could be embedded; see errors above.")

    return TripletDataset(embedded_dataset)

# Train a model with a triplet loss.
def train_triplet_model(product_model, train_loader, num_epochs=10, learning_rate=0.001, margin=1.0):
    """
    Train ``product_model`` with a triplet margin loss using Adam.

    For each batch, the same model maps the anchor, positive and negative
    inputs to vectors. The loss is
    ``mean(max(d(a, p) - d(a, n) + margin, 0))``, where ``d`` is
    ``F.pairwise_distance`` (Euclidean by default).

    Args:
        product_model: ``torch.nn.Module`` to optimize in place.
        train_loader: Iterable yielding ``(anchor, positive, negative)``
            tensor batches, e.g. a ``DataLoader`` over ``TripletDataset``.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        margin: Required gap between the negative and positive distances.

    Returns:
        The same ``product_model`` object, after training.
    """
    optimizer = Adam(product_model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        product_model.train()
        total_loss = 0.0
        num_batches = 0

        for anchor, positive, negative in train_loader:
            optimizer.zero_grad()

            # Forward pass: one shared tower embeds all three inputs.
            anchor_vec = product_model(anchor)
            positive_vec = product_model(positive)
            negative_vec = product_model(negative)

            # Triplet loss: pull positives closer than negatives by `margin`.
            positive_distance = F.pairwise_distance(anchor_vec, positive_vec)
            negative_distance = F.pairwise_distance(anchor_vec, negative_vec)
            triplet_loss = torch.clamp(positive_distance - negative_distance + margin, min=0).mean()

            # Backpropagation and optimizer step.
            triplet_loss.backward()
            optimizer.step()

            total_loss += triplet_loss.item()
            num_batches += 1

        # Count batches directly: dividing by len(train_loader) raised
        # ZeroDivisionError for an empty loader and required a sized iterable.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")

    return product_model

# Model training pipeline.
def main():
    """
    Train the product tower on MongoDB triplets and save its weights.

    Builds a small MLP over BERT embeddings, trains it with
    ``train_triplet_model`` on the data from ``prepare_training_data``, and
    writes the resulting ``state_dict`` to ``product_model.pth``.
    """
    # Example tower: BERT embedding (768-d) -> 256 -> 128.
    bert_dim, hidden_dim, output_dim = 768, 256, 128
    layers = [
        torch.nn.Linear(bert_dim, hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_dim, output_dim),
    ]
    product_model = torch.nn.Sequential(*layers)

    # Data preparation: embedded triplets, shuffled mini-batches of 16.
    loader = DataLoader(prepare_training_data(), batch_size=16, shuffle=True)

    product_model = train_triplet_model(product_model, loader)

    # Persist only the learned parameters.
    torch.save(product_model.state_dict(), "product_model.pth")
    print("Model training completed and saved.")

# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()