waseoke commited on
Commit
09ec91f
·
verified ·
1 Parent(s): d0b0df6

Create train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +16 -0
train_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ for epoch in range(num_epochs):
2
+ optimizer.zero_grad()
3
+ anchor_vec = product_model(anchor_data)
4
+ positive_vec = product_model(positive_data)
5
+ negative_vec = product_model(negative_data)
6
+
7
+ # 트립렛 손실 계산
8
+ positive_distance = F.pairwise_distance(anchor_vec, positive_vec)
9
+ negative_distance = F.pairwise_distance(anchor_vec, negative_vec)
10
+ triplet_loss = torch.clamp(positive_distance - negative_distance + margin, min=0).mean()
11
+
12
+ # 역전파와 최적화
13
+ triplet_loss.backward()
14
+ optimizer.step()
15
+
16
+ print(f"Epoch {epoch + 1}, Loss: {triplet_loss.item()}")