Nischay103 commited on
Commit
9c56e39
1 Parent(s): f3e0ff3

Update tokenizer.py

Browse files
Files changed (1) hide show
  1. tokenizer.py +108 -108
tokenizer.py CHANGED
@@ -1,108 +1,108 @@
1
- import torch
2
- from abc import ABC, abstractmethod
3
- from typing import List, Optional, Tuple
4
- from torch import Tensor
5
- from torch.nn.utils.rnn import pad_sequence
6
-
7
-
8
- class BaseTokenizer(ABC):
9
-
10
- def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
11
- self._itos = specials_first + tuple(charset + '[UNK]') + specials_last
12
- self._stoi = {s: i for i, s in enumerate(self._itos)}
13
-
14
- def __len__(self):
15
- return len(self._itos)
16
-
17
- def _tok2ids(self, tokens: str) -> List[int]:
18
- return [self._stoi[s] for s in tokens]
19
-
20
- def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
21
- tokens = [self._itos[i] for i in token_ids]
22
- return ''.join(tokens) if join else tokens
23
-
24
- @abstractmethod
25
- def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
26
- raise NotImplementedError
27
-
28
- @abstractmethod
29
- def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
30
- """Internal method which performs the necessary filtering prior to decoding."""
31
- raise NotImplementedError
32
-
33
- def decode(self, token_dists: Tensor, beam_width: int = 1, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
34
- if beam_width > 1:
35
- return self.beam_search_decode(token_dists, beam_width, raw)
36
- else:
37
- return self.greedy_decode(token_dists, raw)
38
-
39
- def greedy_decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
40
- batch_tokens = []
41
- batch_probs = []
42
- for dist in token_dists:
43
- probs, ids = dist.max(-1) # greedy selection
44
- if not raw:
45
- probs, ids = self._filter(probs, ids)
46
- tokens = self._ids2tok(ids, not raw)
47
- batch_tokens.append(tokens)
48
- batch_probs.append(probs)
49
- return batch_tokens, batch_probs
50
-
51
- def beam_search_decode(self, token_dists: Tensor, beam_width: int, raw: bool) -> Tuple[List[str], List[Tensor]]:
52
- batch_tokens = []
53
- batch_probs = []
54
-
55
- for dist in token_dists:
56
- sequences = [([], 1.0)]
57
- for step_dist in dist:
58
- all_candidates = []
59
- for seq, score in sequences:
60
- top_probs, top_ids = step_dist.topk(beam_width)
61
- for i in range(beam_width):
62
- candidate = (seq + [top_ids[i].item()],
63
- score * top_probs[i].item())
64
- all_candidates.append(candidate)
65
- ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
66
- sequences = ordered[:beam_width]
67
-
68
- best_sequence, best_score = sequences[0]
69
- if not raw:
70
- best_score_tensor = torch.tensor([best_score])
71
- best_sequence_tensor = torch.tensor(best_sequence)
72
- best_score_tensor, best_sequence = self._filter(
73
- best_score_tensor, best_sequence_tensor)
74
- best_score = best_score_tensor.item()
75
- tokens = self._ids2tok(best_sequence, not raw)
76
- batch_tokens.append(tokens)
77
- batch_probs.append(best_score)
78
-
79
- return batch_tokens, batch_probs
80
-
81
-
82
- class Tokenizer(BaseTokenizer):
83
- BOS = '[B]'
84
- EOS = '[E]'
85
- PAD = '[P]'
86
-
87
- def __init__(self, charset: str) -> None:
88
- specials_first = (self.EOS,)
89
- specials_last = (self.BOS, self.PAD)
90
- super().__init__(charset, specials_first, specials_last)
91
- self.eos_id, self.bos_id, self.pad_id = [
92
- self._stoi[s] for s in specials_first + specials_last]
93
-
94
- def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
95
- batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
96
- for y in labels]
97
- return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
98
-
99
- def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
100
- ids = ids.tolist()
101
- try:
102
- eos_idx = ids.index(self.eos_id)
103
- except ValueError:
104
- eos_idx = len(ids)
105
- # Truncate after EOS
106
- ids = ids[:eos_idx]
107
- probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists)
108
- return probs, ids
 
1
+ import torch
2
+ from abc import ABC, abstractmethod
3
+ from typing import List, Optional, Tuple
4
+ from torch import Tensor
5
+ from torch.nn.utils.rnn import pad_sequence
6
+
7
+
8
+ class BaseTokenizer(ABC):
9
+
10
+ def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
11
+ self._itos = specials_first + tuple(charset + '[UNK]') + specials_last
12
+ self._stoi = {s: i for i, s in enumerate(self._itos)}
13
+
14
+ def __len__(self):
15
+ return len(self._itos)
16
+
17
+ def _tok2ids(self, tokens: str) -> List[int]:
18
+ return [self._stoi[s] for s in tokens]
19
+
20
+ def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
21
+ tokens = [self._itos[i] for i in token_ids]
22
+ return ''.join(tokens) if join else tokens
23
+
24
+ @abstractmethod
25
+ def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
26
+ raise NotImplementedError
27
+
28
+ @abstractmethod
29
+ def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
30
+ """Internal method which performs the necessary filtering prior to decoding."""
31
+ raise NotImplementedError
32
+
33
+ def decode(self, token_dists: Tensor, beam_width: int = 1, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
34
+ if beam_width > 1:
35
+ return self.beam_search_decode(token_dists, beam_width, raw)
36
+ else:
37
+ return self.greedy_decode(token_dists, raw)
38
+
39
+ def greedy_decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
40
+ batch_tokens = []
41
+ batch_probs = []
42
+ for dist in token_dists:
43
+ probs, ids = dist.max(-1)
44
+ if not raw:
45
+ probs, ids = self._filter(probs, ids)
46
+ tokens = self._ids2tok(ids, not raw)
47
+ batch_tokens.append(tokens)
48
+ batch_probs.append(probs)
49
+ return batch_tokens, batch_probs
50
+
51
+ def beam_search_decode(self, token_dists: Tensor, beam_width: int, raw: bool) -> Tuple[List[str], List[Tensor]]:
52
+ batch_tokens = []
53
+ batch_probs = []
54
+
55
+ for dist in token_dists:
56
+ sequences = [([], 1.0)]
57
+ for step_dist in dist:
58
+ all_candidates = []
59
+ for seq, score in sequences:
60
+ top_probs, top_ids = step_dist.topk(beam_width)
61
+ for i in range(beam_width):
62
+ candidate = (seq + [top_ids[i].item()],
63
+ score * top_probs[i].item())
64
+ all_candidates.append(candidate)
65
+ ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
66
+ sequences = ordered[:beam_width]
67
+
68
+ best_sequence, best_score = sequences[0]
69
+ if not raw:
70
+ best_score_tensor = torch.tensor([best_score])
71
+ best_sequence_tensor = torch.tensor(best_sequence)
72
+ best_score_tensor, best_sequence = self._filter(
73
+ best_score_tensor, best_sequence_tensor)
74
+ best_score = best_score_tensor.item()
75
+ tokens = self._ids2tok(best_sequence, not raw)
76
+ batch_tokens.append(tokens)
77
+ batch_probs.append(best_score)
78
+
79
+ return batch_tokens, batch_probs
80
+
81
+
82
+ class Tokenizer(BaseTokenizer):
83
+ BOS = '[B]'
84
+ EOS = '[E]'
85
+ PAD = '[P]'
86
+
87
+ def __init__(self, charset: str) -> None:
88
+ specials_first = (self.EOS,)
89
+ specials_last = (self.BOS, self.PAD)
90
+ super().__init__(charset, specials_first, specials_last)
91
+ self.eos_id, self.bos_id, self.pad_id = [
92
+ self._stoi[s] for s in specials_first + specials_last]
93
+
94
+ def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
95
+ batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
96
+ for y in labels]
97
+ return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
98
+
99
+ def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
100
+ ids = ids.tolist()
101
+ try:
102
+ eos_idx = ids.index(self.eos_id)
103
+ except ValueError:
104
+ eos_idx = len(ids)
105
+ # Truncate after EOS
106
+ ids = ids[:eos_idx]
107
+ probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists)
108
+ return probs, ids