Thor Kell commited on
Commit
628e563
·
1 Parent(s): 772edf4

add python code

Browse files
Files changed (3) hide show
  1. parse_tracklists.py +68 -0
  2. runner.py +21 -0
  3. trainer.py +125 -0
parse_tracklists.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import re
3
+
4
+
5
+ def load_lines(filename):
6
+ lines = []
7
+ with open(filename) as f:
8
+ for line in f:
9
+ lines.append(line.strip())
10
+ return lines
11
+
12
+
13
+ def remove_titles_and_bad_tracks(lines):
14
+ is_track = re.compile(r"^\d.*")
15
+ better_lines = []
16
+ for line in lines:
17
+ if is_track.match(line) and "???" not in line:
18
+ better_lines.append(line)
19
+ return better_lines
20
+
21
+
22
+ def group_by_set(lines):
23
+ is_set_title = re.compile(r".*:$")
24
+ is_track = re.compile(r"^\d.*:")
25
+ grouped_lines = []
26
+
27
+ current_set = []
28
+ for line in lines:
29
+ if not line.strip():
30
+ continue
31
+ if is_set_title.match(line) and len(current_set) > 0:
32
+ grouped_lines.append(current_set)
33
+ current_set = []
34
+ elif is_track.match(line) and "???" not in line:
35
+ current_set.append(line)
36
+
37
+ return grouped_lines
38
+
39
+
40
+ def get_grouped_artists(grouped_lines):
41
+ artist_from_track = re.compile(r"\d+\: (.+?) - .+?")
42
+ artist_names = []
43
+ for dj_set_lines in grouped_lines:
44
+ dj_set_artists = []
45
+ for line in dj_set_lines:
46
+ if artist_match := artist_from_track.match(line):
47
+ artist_name = artist_match.group(1).strip().lower()
48
+ dj_set_artists.append(artist_name)
49
+ artist_names.append(dj_set_artists)
50
+
51
+ return artist_names
52
+
53
+
54
+ def write_to_csv(filename):
55
+ with open(output_filename, "w", newline="") as csvfile:
56
+ writer = csv.writer(csvfile)
57
+ for artists in artist_names:
58
+ writer.writerow(artists)
59
+
60
+
61
+ if __name__ == "__main__":
62
+ filename = "data/radio-original.txt"
63
+ output_filename = "data/artist-names-per-row.csv"
64
+
65
+ lines = load_lines(filename)
66
+ grouped_lines = group_by_set(lines)
67
+ artist_names = get_grouped_artists(grouped_lines)
68
+ write_to_csv(output_filename)
runner.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from trainer import CBOW, TextPreProcessor, make_context_vector
3
+
4
+ if __name__ == "__main__":
5
+ artist_names = "data/artist-names-per-row.csv"
6
+ model_path = "data/cbow-model-weights"
7
+
8
+ text = TextPreProcessor(artist_names)
9
+ vocab = text.build_vocab()
10
+ model = CBOW(vocab)
11
+
12
+ model.load_state_dict(torch.load(model_path))
13
+ model.eval()
14
+ print("Loaded model")
15
+
16
+ context = ["ana roxanne", "bjork"]
17
+ context_vector = make_context_vector(context, model.word_to_ix)
18
+ a = model(context_vector)
19
+ prediction = model.ix_to_word[torch.argmax(a[0]).item()]
20
+ print(f"Context: {context}\n")
21
+ print(f"Prediction: {prediction}")
trainer.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import torch
3
+ from torchtext.vocab import build_vocab_from_iterator
4
+
5
+
6
+ class TextPreProcessor:
7
+ def __init__(self, input_file):
8
+ self.input_file = input_file
9
+ self.context_size = 1
10
+
11
+ def build_training_data(self):
12
+ data = []
13
+ for row in self._generate_rows():
14
+ for i in range(self.context_size, len(row) - self.context_size):
15
+ before = row[i - 1].lower()
16
+ target = row[i].lower()
17
+ after = row[i + 1].lower()
18
+
19
+ context_one = [before, after]
20
+ context_two = [after, before]
21
+ data.append((context_one, target))
22
+ data.append((context_two, target))
23
+
24
+ return data
25
+
26
+ def build_vocab(self):
27
+ rows_of_artists = self._generate_rows()
28
+ our_vocab = build_vocab_from_iterator(
29
+ rows_of_artists, specials=["<unk>"], min_freq=1
30
+ )
31
+
32
+ return our_vocab
33
+
34
+ def _generate_rows(self):
35
+ with open(self.input_file, encoding="utf-8") as f:
36
+ reader = csv.reader(f)
37
+ for row in reader:
38
+ yield row
39
+
40
+
41
+ class CBOW(torch.nn.Module):
42
+ def __init__(self, vocab):
43
+ super(CBOW, self).__init__()
44
+ self.num_epochs = 3
45
+ self.context_size = 1 # 1 word to the left, 1 to the right
46
+ self.embedding_dim = 100 # embedding vector size
47
+ self.learning_rate = 0.001
48
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+
50
+ self.vocab = vocab
51
+ self.word_to_ix = self.vocab.get_stoi()
52
+ self.ix_to_word = self.vocab.get_itos()
53
+ self.vocab_list = list(self.vocab.get_stoi().keys())
54
+ self.vocab_size = len(self.vocab)
55
+
56
+ self.model = None
57
+
58
+ # out: 1 x embedding_dim
59
+ # initialize an Embedding matrix based on our inputs
60
+ self.embeddings = torch.nn.Embedding(self.vocab_size, self.embedding_dim)
61
+ self.linear1 = torch.nn.Linear(self.embedding_dim, 128)
62
+ self.activation_function1 = torch.nn.ReLU()
63
+
64
+ # out: 1 x vocab_size
65
+ self.linear2 = torch.nn.Linear(128, self.vocab_size)
66
+ self.activation_function2 = torch.nn.LogSoftmax(dim=-1)
67
+
68
+ def forward(self, inputs):
69
+ embeds = sum(self.embeddings(inputs)).view(1, -1)
70
+ out = self.linear1(embeds)
71
+ out = self.activation_function1(out)
72
+ out = self.linear2(out)
73
+ out = self.activation_function2(out)
74
+ return out
75
+
76
+ def get_word_emdedding(self, word):
77
+ word = torch.tensor([self.word_to_ix[word]])
78
+ # Embeddings lookup of a single word,
79
+ # once the Embeddings layer has been optimized
80
+ return self.embeddings(word).view(1, -1)
81
+
82
+
83
+ def make_context_vector(context, word_to_ix):
84
+ idxs = [word_to_ix[w] for w in context]
85
+ return torch.tensor(idxs, dtype=torch.long)
86
+
87
+
88
+ if __name__ == "__main__":
89
+ artist_names = "data/artist-names-per-row.csv"
90
+ model_path = "data/cbow-model-weights"
91
+ text = TextPreProcessor(artist_names)
92
+ training_data = text.build_training_data()
93
+ vocab = text.build_vocab()
94
+ cbow = CBOW(vocab)
95
+
96
+ loss_function = torch.nn.NLLLoss()
97
+ optimizer = torch.optim.SGD(cbow.parameters(), lr=0.001)
98
+
99
+ # 50 to start with, no correct answer here
100
+ for epoch in range(50):
101
+ # we start tracking how accurate our intial words are
102
+ total_loss = 0
103
+
104
+ # for the x, y in the training data:
105
+ for context, target in training_data:
106
+ context_vector = make_context_vector(context, cbow.word_to_ix)
107
+
108
+ # we look at loss
109
+ log_probs = cbow(context_vector)
110
+
111
+ # we compare the loss from what the actual word is, related to the
112
+ # probaility of the words
113
+ total_loss += loss_function(
114
+ log_probs, torch.tensor([cbow.word_to_ix[target]])
115
+ )
116
+
117
+ # optimize at the end of each epoch
118
+ optimizer.zero_grad()
119
+ total_loss.backward()
120
+ optimizer.step()
121
+
122
+ # Log out some metrics to see if loss decreases
123
+ print("end of epoch {} | loss {:2.3f}".format(epoch, total_loss))
124
+ torch.save(cbow.state_dict(), model_path)
125
+ print("saved model!")