File size: 664 Bytes
628e563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from trainer import CBOW, TextPreProcessor, make_context_vector

def main() -> None:
    """Load trained CBOW weights and print the model's prediction for a fixed context.

    Rebuilds the vocabulary from the artist-name CSV and constructs ``CBOW``
    from it. The vocabulary presumably has to match the one used at training
    time for the saved ``state_dict`` to fit; confirm against the trainer.
    It then restores the weights, switches to eval mode, runs one forward
    pass on a hard-coded context, and prints the highest-scoring word.
    """
    artist_names = "data/artist-names-per-row.csv"
    model_path = "data/cbow-model-weights"

    text = TextPreProcessor(artist_names)
    vocab = text.build_vocab()
    model = CBOW(vocab)

    # map_location="cpu" lets weights saved from a GPU run load on a machine
    # without CUDA. load_state_dict then copies each tensor into the existing
    # parameter, whatever device that parameter lives on.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # weights. Consider weights_only=True once torch>=1.13 is guaranteed.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    print("Loaded model")

    context = ["ana roxanne", "bjork"]
    context_vector = make_context_vector(context, model.word_to_ix)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        scores = model(context_vector)
    # Take the argmax over scores[0] (first entry along dim 0, presumably the
    # single context's scores; confirm against CBOW.forward), then map that
    # index back to a word.
    prediction = model.ix_to_word[torch.argmax(scores[0]).item()]
    print(f"Context: {context}\n")
    print(f"Prediction: {prediction}")


if __name__ == "__main__":
    main()