|
import torch |
|
from trainer import CBOW, TextPreProcessor, make_context_vector |
|
|
|
if __name__ == "__main__":
    # Load a trained CBOW model from disk and print its prediction for a
    # sample context of artist names.

    artist_names = "data/artist-names-per-row.csv"
    model_path = "data/cbow-model-weights"

    # Rebuild the vocabulary from the training corpus so the freshly
    # constructed model has parameter shapes matching the saved weights.
    # NOTE(review): assumes build_vocab() is deterministic for a given file
    # (same word->index mapping as at training time) — TODO confirm.
    text = TextPreProcessor(artist_names)
    vocab = text.build_vocab()
    model = CBOW(vocab)

    # map_location="cpu" lets a state dict saved from a CUDA run be
    # deserialized on a machine without a GPU; load_state_dict then copies
    # the tensors into the model's existing parameters.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    # Switch modules to evaluation mode before inference.
    model.eval()
    print("Loaded model")

    context = ["ana roxanne", "bjork"]
    context_vector = make_context_vector(context, model.word_to_ix)

    # Inference only: no gradients needed, so skip building the autograd graph.
    with torch.no_grad():
        scores = model(context_vector)

    # Take the highest-scoring index from the first row of the output and map
    # it back to a word. NOTE(review): presumably scores is shaped
    # (1, vocab_size) — TODO confirm against CBOW.forward.
    prediction = model.ix_to_word[torch.argmax(scores[0]).item()]

    print(f"Context: {context}\n")
    print(f"Prediction: {prediction}")
|
|