Thor Kell
add python code
628e563
raw
history blame contribute delete
664 Bytes
import torch
from trainer import CBOW, TextPreProcessor, make_context_vector
if __name__ == "__main__":
    # Load a trained CBOW model from disk and print its prediction for a
    # fixed, hard-coded context of artist names.
    artist_names = "data/artist-names-per-row.csv"
    model_path = "data/cbow-model-weights"

    # The vocabulary is rebuilt from the CSV rather than loaded from disk.
    # NOTE(review): presumably it must match the vocabulary used at training
    # time, otherwise load_state_dict will reject mismatched weights — confirm.
    text = TextPreProcessor(artist_names)
    vocab = text.build_vocab()
    model = CBOW(vocab)

    # map_location="cpu" lets weights that were saved from a GPU run load on
    # a CPU-only machine; load_state_dict then copies them into the model.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # weights (consider weights_only=True on torch >= 1.13).
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    print("Loaded model")

    context = ["ana roxanne", "bjork"]
    context_vector = make_context_vector(context, model.word_to_ix)

    # Inference only: no gradients are needed, so skip building the autograd
    # graph for this forward pass.
    with torch.no_grad():
        a = model(context_vector)

    # Assumes a[0] holds one score per vocabulary index — TODO confirm the
    # output shape of CBOW.forward.
    prediction = model.ix_to_word[torch.argmax(a[0]).item()]
    print(f"Context: {context}\n")
    print(f"Prediction: {prediction}")