import torch
import torchvision.transforms as transforms
from PIL import Image


def print_examples(model, device, vocab):
    """Caption a handful of test images and print the predictions."""
    # Inception-style preprocessing: resize to 299x299 and normalize to [-1, 1].
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    model.eval()

    test_images = ["dog.png", "dirt_bike.png", "surfing.png", "horse.png", "camera.png"]
    for filename in test_images:
        img = transform(
            Image.open(f"./test_examples/{filename}").convert("RGB")
        ).unsqueeze(0)
        caption = " ".join(model.caption_image(img.to(device), vocab))
        print(f"{filename} PREDICTION: {caption}")

    model.train()


def save_checkpoint(state, filename="/content/drive/MyDrive/checkpoints/Seq2Seq.pt"):
    """Serialize a training-state dict (model, optimizer, step) to disk."""
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from a checkpoint and return the saved step."""
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    step = checkpoint["step"]
    return step
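

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of how these helpers are typically wired into a training
# script. `CNNtoRNN`, `vocab`, and the hyperparameters are assumptions about
# the rest of the project, not definitions made in this file:
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = CNNtoRNN(embed_size=256, hidden_size=256,
#                    vocab_size=len(vocab), num_layers=1).to(device)
#   optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
#
#   # Resume from a previous run (checkpoint file assumed to exist):
#   step = load_checkpoint(torch.load("Seq2Seq.pt"), model, optimizer)
#
#   # ... training loop updates `model`, `optimizer`, and `step` ...
#
#   save_checkpoint({
#       "state_dict": model.state_dict(),
#       "optimizer": optimizer.state_dict(),
#       "step": step,
#   })
#   print_examples(model, device, vocab)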