File size: 1,058 Bytes
2bffc21
 
 
 
09ec91f
2bffc21
 
09ec91f
2bffc21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

def train_triplet_model(product_model, anchor_data, positive_data, negative_data, num_epochs=10, learning_rate=0.001, margin=1.0, *, p=2.0, verbose=True):
    """Train ``product_model`` as an embedding network with a triplet margin loss.

    Each epoch performs exactly one full-batch optimisation step: the model
    embeds the anchor, positive and negative inputs, and Adam minimises
    ``mean(max(d(a, pos) - d(a, neg) + margin, 0))``, where ``d`` is the
    ``p``-norm distance computed row-wise along the last dimension of the
    embeddings.

    Args:
        product_model: A ``torch.nn.Module`` mapping an input batch to
            embeddings. Its parameters are updated in place.
        anchor_data: Input batch for the anchors, passed directly to the model.
        positive_data: Input batch whose embeddings should lie close to the
            corresponding anchor embeddings.
        negative_data: Input batch whose embeddings should lie at least
            ``margin`` farther from the anchor than the positive does.
        num_epochs: Number of epochs (one optimiser step each).
        learning_rate: Adam learning rate.
        margin: Required gap between the negative and positive distances.
        p: Norm degree of the pairwise distance (default 2.0, Euclidean).
        verbose: If True, print the pre-step loss after every epoch.

    Returns:
        The same ``product_model`` object, after training (left in train mode).
    """
    optimizer = Adam(product_model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        product_model.train()
        optimizer.zero_grad()

        # The three batches go through the model in separate calls, so any
        # layer that uses batch statistics sees each batch on its own.
        anchor_vec = product_model(anchor_data)
        positive_vec = product_model(positive_data)
        negative_vec = product_model(negative_data)

        # Built-in equivalent of
        # clamp(d(a, pos) - d(a, neg) + margin, min=0).mean().
        triplet_loss = F.triplet_margin_loss(
            anchor_vec, positive_vec, negative_vec, margin=margin, p=p,
        )

        triplet_loss.backward()
        optimizer.step()

        if verbose:
            print(f"Epoch {epoch + 1}, Loss: {triplet_loss.item()}")

    return product_model