File size: 2,357 Bytes
74e8f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
/**
 * @license
 * Copyright Big Vision Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {stringToChars, TOKEN_SEPARATOR, Vocabulary, Tokenizer as TokenizerInterface} from './common';

/**
 * A mergeable pair of adjacent pieces discovered during encoding.
 */
interface Candidate {
  /** Concatenation of the two adjacent pieces; a key present in the vocabulary. */
  piece: string;
  /** Index of the pair's left piece within the current list of pieces. */
  pos: number;
  /** Vocabulary score of `piece`. */
  score: number;
}

/**
 * Comparator that orders candidates from highest to lowest score: it is
 * negative when the first argument scores higher, positive when the second
 * does, and zero on a tie.
 */
const scoreDesc = ({score: lhsScore}: Candidate, {score: rhsScore}: Candidate): number =>
  rhsScore - lhsScore;

/**
 * Prepares raw text for tokenization.
 *
 * Applies NFKC normalization. For non-empty text only, it then swaps every
 * ASCII space (U+0020) for TOKEN_SEPARATOR and prefixes one extra
 * TOKEN_SEPARATOR. Empty text is returned unchanged.
 */
function processInput(str: string): string {
  const text = str.normalize('NFKC');
  if (text.length === 0) {
    // Nothing to mark: empty input stays empty, with no leading separator.
    return text;
  }
  const separated = text.replace(/ /g, TOKEN_SEPARATOR);
  return `${TOKEN_SEPARATOR}${separated}`;
}

/**
 * Sentencepiece tokenizer implementing the BPE algorithm.
 */
export class Tokenizer implements TokenizerInterface {

  // piece -> [score, index of the piece's entry in the vocabulary]
  private readonly map: Map<string, [number, number]>;

  /**
   * @param vocabulary Ordered entries, each destructured as `[piece, score]`.
   *     An entry's position in the list becomes that piece's token id.
   * @throws Error if the same piece appears more than once.
   */
  constructor(vocabulary: Vocabulary) {
    this.map = new Map<string, [number, number]>();
    vocabulary.forEach(([piece, score], idx) => {
      if (this.map.has(piece)) {
        throw new Error(`Piece "${piece}" occurs multiple times in vocabulary`);
      }
      this.map.set(piece, [score, idx]);
    });
  }

  /**
   * Encodes `input` into vocabulary ids.
   *
   * Encoding starts from the characters of the normalized input. It then
   * repeatedly fuses the adjacent pair whose concatenation has the highest
   * vocabulary score, choosing the leftmost pair on ties. It stops when no
   * adjacent pair forms a vocabulary piece.
   *
   * @param input Raw text; it is normalized by `processInput` before merging.
   * @returns The id of every piece left after merging, in input order. The
   *     result is empty for empty input.
   * @throws Error if a piece left after merging (e.g. a character that never
   *     fused with a neighbour) has no vocabulary entry.
   */
  encode(input: string): number[] {
    // Copy once so the in-place merges below never mutate an array returned
    // by stringToChars.
    const pieces: string[] = [...stringToChars(processInput(input))];

    while (true) {
      // Find the best mergeable adjacent pair in a single pass. A candidate
      // must score strictly higher to replace the current best, so ties go to
      // the leftmost pair.
      let best: Candidate | undefined;
      for (let i = 0; i < pieces.length - 1; i++) {
        const fused = pieces[i] + pieces[i + 1];
        const entry = this.map.get(fused);
        if (entry === undefined) {
          continue;
        }
        const candidate: Candidate = {piece: fused, pos: i, score: entry[0]};
        if (best === undefined || scoreDesc(candidate, best) < 0) {
          best = candidate;
        }
      }
      if (best === undefined) {
        break;
      }
      // Replace the two pieces at best.pos with their concatenation.
      pieces.splice(best.pos, 2, best.piece);
    }

    return pieces.map(piece => {
      const entry = this.map.get(piece);
      if (entry === undefined) {
        throw new Error(`Piece "${piece}" is not in the vocabulary`);
      }
      return entry[1];
    });
  }
}