File size: 1,050 Bytes
9e582c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from src.translator import Translator
import torch
import random
from src.helper import get_data

# Reseed Python's RNG without an explicit seed (OS entropy / time), so runs
# are not reproducible; torch's RNG is not seeded here at all.
random.seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters for the "vanilla" model. Presumably these must match the
# configuration the checkpoints below were trained with — TODO confirm.
params = {
    "embed_size": 16,
    "hidden_size": 512,
    "cell_type": "LSTM",
    "num_layers": 2,
    "dropout": 0.1,
    "learning_rate": 0.005,
    "optimizer": "SGD",
    "teacher_forcing_ratio": 0.5,
    "max_length": 50
}

model = Translator("tam", params, device)

# map_location remaps tensors saved on another device (e.g. a GPU) onto the
# device chosen above, so the script also works on CPU-only machines.
model.encoder.load_state_dict(
    torch.load("./best_model_vanilla/encoder.pt", map_location=device)
)
model.decoder.load_state_dict(
    torch.load("./best_model_vanilla/decoder.pt", map_location=device)
)
# Inference only: switch off dropout (params["dropout"] = 0.1) so each
# prediction is deterministic for a given input.
model.encoder.eval()
model.decoder.eval()

# Load the data before opening the output file so a loading failure does not
# leave behind a truncated/empty test_gen.txt.
test_data = get_data("tam", "test")
# len() rather than truthiness: test_data may be an array-like whose truth
# value is ambiguous.
num_examples = len(test_data)

# Number of exact-match predictions (a count, converted to a percentage below).
accuracy = 0
# Explicit UTF-8: the file holds Tamil text, which the platform default
# encoding may not be able to represent. no_grad avoids building autograd
# graphs during decoding.
with open("test_gen.txt", "w", encoding="utf-8") as f, torch.no_grad():
    f.write("Input, Target, Output\n")
    for example in test_data:
        # Only the first two fields are used; extra columns are ignored.
        source, target = example[0], example[1]
        # Decode once and reuse the result for both the log line and the
        # accuracy check, so the written output is exactly what gets scored.
        prediction = model.evaluate(source)
        f.write(source + ", " + target + ", " + prediction + "\n")
        if target == prediction:
            accuracy += 1

if num_examples == 0:
    print("Test Accuracy: n/a (no test examples)")
else:
    print("Test Accuracy: " + str(accuracy / num_examples * 100) + "%")