File size: 607 Bytes
09ec91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Triplet-loss training loop: each epoch embeds the anchor / positive / negative
# inputs with `product_model`, computes a hinge-style triplet loss and takes one
# optimizer step.
# NOTE(review): `num_epochs`, `optimizer`, `product_model`, `anchor_data`,
# `positive_data`, `negative_data` and `margin` are defined outside this view —
# presumably a torch.nn.Module, a torch.optim optimizer and batched tensors; confirm.
for epoch in range(num_epochs):
    # Clear gradients left over from the previous step before the new backward pass.
    optimizer.zero_grad()

    # Embed the three inputs; the generator is consumed left to right, so the
    # model is still called in anchor -> positive -> negative order.
    anchor_vec, positive_vec, negative_vec = (
        product_model(batch)
        for batch in (anchor_data, positive_data, negative_data)
    )

    # Triplet loss: distance(anchor, positive) should undercut
    # distance(anchor, negative) by at least `margin`; violations are averaged.
    positive_distance, negative_distance = (
        F.pairwise_distance(anchor_vec, other)
        for other in (positive_vec, negative_vec)
    )
    triplet_loss = (positive_distance - negative_distance + margin).clamp(min=0).mean()

    # Backpropagate and apply one parameter update.
    triplet_loss.backward()
    optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {triplet_loss.item()}")