import torch
import torchvision.transforms as transforms
from PIL import Image
def print_examples(model, device, vocab):
    """Caption a fixed set of test images and print the predictions."""
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),  # 299x299 is the Inception-v3 input size
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    model.eval()  # disable dropout / batch-norm updates while captioning
    test_images = ["dog.png", "dirt_bike.png", "surfing.png", "horse.png", "camera.png"]
    for name in test_images:
        img = transform(Image.open(f"./test_examples/{name}").convert("RGB")).unsqueeze(0)
        print(f"{name} PREDICTION: " + " ".join(model.caption_image(img.to(device), vocab)))
    model.train()  # restore training mode for the caller
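

# `model.caption_image` is relied on above but not defined in this file. As a
# point of reference, a minimal greedy-decoding sketch is given below; the
# submodule names `encoderCNN` / `decoderRNN` and the `vocab.itos` id-to-token
# mapping are assumptions about a typical CNN-to-RNN captioner, not something
# this file guarantees.
def greedy_caption(model, image, vocab, max_length=50):
    """Decode one token at a time, always taking the argmax (greedy) token."""
    result = []
    with torch.no_grad():
        x = model.encoderCNN(image).unsqueeze(0)  # image features seed the LSTM
        states = None
        for _ in range(max_length):
            hiddens, states = model.decoderRNN.lstm(x, states)
            output = model.decoderRNN.linear(hiddens.squeeze(0))
            predicted = output.argmax(1)  # most likely token at this step
            result.append(predicted.item())
            x = model.decoderRNN.embed(predicted).unsqueeze(0)  # feed token back in
            if vocab.itos[predicted.item()] == "<EOS>":
                break
    return [vocab.itos[idx] for idx in result]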
def save_checkpoint(state, filename="/content/drive/MyDrive/checkpoints/Seq2Seq.pt"):
    """Persist the training state (model, optimizer, step) to Google Drive."""
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer weights; return the saved global step."""
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    step = checkpoint["step"]
    return step
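

# A minimal sketch of how the helpers above fit into a training loop. The
# `loader`, `criterion`, and `num_epochs` names are assumptions for
# illustration, as is the model call signature (it assumes the decoder prepends
# the image feature, so the output length matches the caption length).
def train(model, optimizer, criterion, loader, device, vocab,
          num_epochs=10, resume=False):
    step = 0
    if resume:
        # map_location keeps CPU-only machines from failing on GPU checkpoints.
        checkpoint = torch.load(
            "/content/drive/MyDrive/checkpoints/Seq2Seq.pt", map_location=device
        )
        step = load_checkpoint(checkpoint, model, optimizer)
    for epoch in range(num_epochs):
        print_examples(model, device, vocab)  # qualitative check each epoch
        save_checkpoint({
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "step": step,
        })
        for imgs, captions in loader:
            imgs, captions = imgs.to(device), captions.to(device)
            # Teacher forcing: feed all but the last caption token,
            # score against the full ground-truth caption.
            outputs = model(imgs, captions[:-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )
            step += 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()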