import numpy as np
import torch
from torch.nn import CrossEntropyLoss
import itertools

def is_hiragana_or_katakana(s):
    # Accept only hiragana/katakana characters; the prolonged sound mark "ー" is
    # rejected here because dashes are inserted separately by add_dashes().
    for char in s:
        if not ('\u3040' <= char <= '\u309F' or '\u30A0' <= char <= '\u30FF') or char == "ー":
            return False
    return True

def add_dakuten_handakuten(query, string_type):
    def convert_to_hiragana(s):
        """Convert the given string to hiragana."""
        result = []
        for char in s:
            if 'ァ' <= char <= 'ヶ':  # katakana -> hiragana (the blocks are offset by 0x60)
                result.append(chr(ord(char) - 96))
            else:
                result.append(char)
        return ''.join(result)

    def convert_to_katakana(s):
        """Convert the given string to katakana."""
        result = []
        for char in s:
            if 'ぁ' <= char <= 'ゖ':  # hiragana -> katakana (the blocks are offset by 0x60)
                result.append(chr(ord(char) + 96))
            else:
                result.append(char)
        return ''.join(result)

    if string_type == "hiragana":
        s = convert_to_hiragana(query)
        dakuon_map = {
            'か': 'が', 'き': 'ぎ', 'く': 'ぐ', 'け': 'げ', 'こ': 'ご',
            'さ': 'ざ', 'し': 'じ', 'す': 'ず', 'せ': 'ぜ', 'そ': 'ぞ',
            'た': 'だ', 'ち': 'ぢ', 'つ': 'づ', 'て': 'で', 'と': 'ど',
            'は': 'ば', 'ひ': 'び', 'ふ': 'ぶ', 'へ': 'べ', 'ほ': 'ぼ'
        }
        handakuon_map = {
            'は': 'ぱ', 'ひ': 'ぴ', 'ふ': 'ぷ', 'へ': 'ぺ', 'ほ': 'ぽ'
        }
    elif string_type == "katakana":
        s = convert_to_katakana(query)
        dakuon_map = {
            'カ': 'ガ', 'キ': 'ギ', 'ク': 'グ', 'ケ': 'ゲ', 'コ': 'ゴ',
            'サ': 'ザ', 'シ': 'ジ', 'ス': 'ズ', 'セ': 'ゼ', 'ソ': 'ゾ',
            'タ': 'ダ', 'チ': 'ヂ', 'ツ': 'ヅ', 'テ': 'デ', 'ト': 'ド',
            'ハ': 'バ', 'ヒ': 'ビ', 'フ': 'ブ', 'ヘ': 'ベ', 'ホ': 'ボ',
            'ウ': 'ヴ'
        }
        handakuon_map = {
            'ハ': 'パ', 'ヒ': 'ピ', 'フ': 'プ', 'ヘ': 'ペ', 'ホ': 'ポ'
        }
    else:
        raise ValueError(f"Unsupported string_type: {string_type}")

    # For each character, list the original character plus its dakuten/handakuten variants
    options = []
    for char in s:
        temp = [char]
        if char in dakuon_map:
            temp.append(dakuon_map[char])
        if char in handakuon_map:
            temp.append(handakuon_map[char])
        options.append(temp)

    # Generate every combination as a candidate string
    candidates = [''.join(p) for p in itertools.product(*options)]
    return candidates

def add_dashes(s):
    if not s:
        return ['']

    # Recursively collect the patterns with "ー" inserted into the rest of the string
    substr_patterns = add_dashes(s[1:])

    # Build the patterns that include the current character
    result = []
    for pattern in substr_patterns:
        result.append(s[0] + pattern)  # concatenate as-is
        result.append(s[0] + 'ー' + pattern)  # concatenate with "ー" inserted

    return result
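
# Quick illustration of the two helpers above (values follow from the maps and the
# recursion defined here):
#   add_dakuten_handakuten("はし", "hiragana")
#     -> ['はし', 'はじ', 'ばし', 'ばじ', 'ぱし', 'ぱじ']
#   add_dashes("かき")
#     -> ['かき', 'かーき', 'かきー', 'かーきー']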

def compute_losses(candidates, model, tokenizer):
    # Returns one averaged per-token cross-entropy loss per candidate string.
    # The tokenizer must have a pad token so that the batch can be padded.
    inputs = tokenizer(candidates, return_tensors="pt", padding=True)
    inputs["labels"] = inputs["input_ids"].masked_fill(inputs["input_ids"] == tokenizer.pad_token_id, -100)
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

        logits = outputs.logits
        labels = inputs["labels"]

        # Shift so each position is predicted from the preceding tokens (causal LM loss)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss(reduction="none")

        losses_flat = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))
        losses_seq = losses_flat.view(shift_labels.shape)
        # Padding positions in the labels were set to -100 above, so mask on -100
        # (not pad_token_id) when averaging the per-token losses over each sequence.
        mask_labels = shift_labels != -100
        losses = torch.sum(losses_seq * mask_labels, -1) / mask_labels.sum(-1)

    return losses

def search_candidates(query, query_candidates, model, tokenizer, top_k=100):
    # Recurse on the prefix; an empty prefix contributes the single empty candidate,
    # which terminates the recursion.
    old_query = query[:-1]
    if not old_query:
        old_candidates = ['']
    elif old_query not in query_candidates:
        old_candidates, _ = search_candidates(old_query, query_candidates, model=model, tokenizer=tokenizer, top_k=top_k)
    else:
        old_candidates, _ = query_candidates[old_query]
        
    string = query[-1]
    # Expand the newly typed character into hiragana and katakana candidates,
    # including dakuten/handakuten variants and optional prolonged sound marks.
    candidates = []
    for string_type in ["hiragana", "katakana"]:
        candidates_ = add_dakuten_handakuten(string, string_type=string_type)
        for candidate_ in candidates_:
            candidates += add_dashes(candidate_)
        
    combinations = itertools.product(old_candidates, candidates)
    new_candidates = [''.join(pair) for pair in combinations]
    
    # Rank candidates by average per-token loss (lower loss = more likely under the LM)
    losses = compute_losses(new_candidates, model=model, tokenizer=tokenizer)
    sorted_items = torch.sort(losses)
    sorted_candidates = np.array(new_candidates)[sorted_items.indices.cpu().numpy()]
    topk_candidates = sorted_candidates[:top_k].tolist()
    topk_losses = sorted_items.values[:top_k].cpu().tolist()
    
    query_candidates[query] = (topk_candidates, topk_losses)
    return topk_candidates, topk_losses
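
# Hedged usage sketch (not part of the original file). It assumes a Hugging Face
# causal LM and tokenizer; the model name, seed query, and pad-token fallback below
# are illustrative choices, not requirements of the code above.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "rinna/japanese-gpt2-medium"  # assumption: any Japanese causal LM should work
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    # compute_losses() pads batches, so a pad token must be available.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    query = "とうきょう"  # kana-only input, as required by is_hiragana_or_katakana()
    assert is_hiragana_or_katakana(query)

    query_candidates = {}
    topk_candidates, topk_losses = search_candidates(
        query, query_candidates, model=model, tokenizer=tokenizer, top_k=10
    )
    for candidate, loss in zip(topk_candidates, topk_losses):
        print(f"{candidate}\t{loss:.3f}")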