waseoke committed
Commit 6c15b41 · verified · 1 Parent(s): b4a0526

Update train_model.py

Files changed (1)
  1. train_model.py +105 -15
train_model.py CHANGED
@@ -1,29 +1,119 @@
 import torch
 import torch.nn.functional as F
 from torch.optim import Adam
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
+from pymongo import MongoClient
+from transformers import BertTokenizer, BertModel
+import numpy as np
 
-def train_triplet_model(product_model, anchor_data, positive_data, negative_data, num_epochs=10, learning_rate=0.001, margin=1.0):
+# MongoDB Atlas connection setup
+client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority")
+db = client["two_tower_model"]
+train_dataset = db["train_dataset"]
+
+# Load the BERT model and tokenizer (e.g. klue/bert-base)
+tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
+bert_model = BertModel.from_pretrained("klue/bert-base")
+
+# Product embedding function
+def embed_product_data(product):
+    """
+    Embed product data.
+    """
+    text = product.get("product_name", "") + " " + product.get("product_description", "")
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
+    outputs = bert_model(**inputs)
+    embedding = outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten()  # mean pooling
+    return embedding
+
+# PyTorch Dataset definition
+class TripletDataset(Dataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        data = self.dataset[idx]
+        anchor = torch.tensor(data["anchor_embedding"], dtype=torch.float32)
+        positive = torch.tensor(data["positive_embedding"], dtype=torch.float32)
+        negative = torch.tensor(data["negative_embedding"], dtype=torch.float32)
+        return anchor, positive, negative
+
+# Load the dataset from MongoDB and convert it to embeddings
+def prepare_training_data():
+    dataset = list(train_dataset.find())  # Fetch the data from MongoDB.
+    if not dataset:
+        raise ValueError("No training data found in MongoDB.")
+
+    # Build the anchor, positive, and negative embeddings
+    embedded_dataset = []
+    for entry in dataset:
+        try:
+            anchor_embedding = embed_product_data(entry["anchor"]["product"])
+            positive_embedding = embed_product_data(entry["positive"]["product"])
+            negative_embedding = embed_product_data(entry["negative"]["product"])
+            embedded_dataset.append({
+                "anchor_embedding": anchor_embedding,
+                "positive_embedding": positive_embedding,
+                "negative_embedding": negative_embedding,
+            })
+        except Exception as e:
+            print(f"Error embedding data: {e}")
+
+    return TripletDataset(embedded_dataset)
+
+# Train the model with triplet loss
+def train_triplet_model(product_model, train_loader, num_epochs=10, learning_rate=0.001, margin=1.0):
     optimizer = Adam(product_model.parameters(), lr=learning_rate)
 
     for epoch in range(num_epochs):
         product_model.train()
-        optimizer.zero_grad()
+        total_loss = 0
 
-        # Forward pass
-        anchor_vec = product_model(anchor_data)
-        positive_vec = product_model(positive_data)
-        negative_vec = product_model(negative_data)
+        for anchor, positive, negative in train_loader:
+            optimizer.zero_grad()
 
-        # Triplet loss calculation
-        positive_distance = F.pairwise_distance(anchor_vec, positive_vec)
-        negative_distance = F.pairwise_distance(anchor_vec, negative_vec)
-        triplet_loss = torch.clamp(positive_distance - negative_distance + margin, min=0).mean()
+            # Forward pass
+            anchor_vec = product_model(anchor)
+            positive_vec = product_model(positive)
+            negative_vec = product_model(negative)
 
-        # Backward pass and optimization
-        triplet_loss.backward()
-        optimizer.step()
+            # Triplet loss calculation
+            positive_distance = F.pairwise_distance(anchor_vec, positive_vec)
+            negative_distance = F.pairwise_distance(anchor_vec, negative_vec)
+            triplet_loss = torch.clamp(positive_distance - negative_distance + margin, min=0).mean()
 
-        print(f"Epoch {epoch + 1}, Loss: {triplet_loss.item()}")
+            # Backward pass and optimization
+            triplet_loss.backward()
+            optimizer.step()
+
+            total_loss += triplet_loss.item()
+
+        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")
 
     return product_model
+
+# Model training pipeline
+def main():
+    # Initialize the model (example model)
+    product_model = torch.nn.Sequential(
+        torch.nn.Linear(768, 256),  # 768: BERT embedding dimension
+        torch.nn.ReLU(),
+        torch.nn.Linear(256, 128)
+    )
+
+    # Prepare the data
+    triplet_dataset = prepare_training_data()
+    train_loader = DataLoader(triplet_dataset, batch_size=16, shuffle=True)
+
+    # Train the model
+    trained_model = train_triplet_model(product_model, train_loader)
+
+    # Save the trained model
+    torch.save(trained_model.state_dict(), "product_model.pth")
+    print("Model training completed and saved.")
+
+if __name__ == "__main__":
+    main()
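
For reference, the clamp-based expression in train_triplet_model is the standard triplet margin loss, mean(max(d(anchor, positive) - d(anchor, negative) + margin, 0)). Below is a minimal sketch, not part of this commit, checking on random tensors that PyTorch's built-in F.triplet_margin_loss produces the same value with its default Euclidean distance and mean reduction:

import torch
import torch.nn.functional as F

# Random stand-ins for the 128-dim vectors that product_model outputs.
anchor = torch.randn(16, 128)
positive = torch.randn(16, 128)
negative = torch.randn(16, 128)
margin = 1.0

# Formulation used in train_model.py.
manual = torch.clamp(
    F.pairwise_distance(anchor, positive)
    - F.pairwise_distance(anchor, negative)
    + margin,
    min=0,
).mean()

# Built-in equivalent (p=2 and reduction="mean" are the defaults).
builtin = F.triplet_margin_loss(anchor, positive, negative, margin=margin)

print(torch.allclose(manual, builtin))  # should print True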
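
main() saves only the state_dict, so using the trained tower later requires rebuilding the same architecture before loading the weights. A minimal sketch, not part of this commit, of how the saved model might be reloaded for inference, assuming the Sequential(768 -> 256 -> 128) model defined in main() and a random placeholder in place of a real embed_product_data vector:

import torch

# Rebuild the architecture that was trained, then load the saved weights.
product_model = torch.nn.Sequential(
    torch.nn.Linear(768, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 128),
)
product_model.load_state_dict(torch.load("product_model.pth"))
product_model.eval()

# Project a single 768-dim BERT embedding down to the 128-dim product space.
with torch.no_grad():
    bert_embedding = torch.randn(1, 768)  # placeholder for an embed_product_data output
    product_vector = product_model(bert_embedding)
print(product_vector.shape)  # torch.Size([1, 128])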