File size: 5,091 Bytes
dcd398a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
import itertools
def is_hiragana_or_katakana(s):
    """Return True when every character of ``s`` is an accepted kana.

    A character is accepted when it lies in U+3040-U+309F or
    U+30A0-U+30FF and is not the prolonged sound mark "ー" (which sits
    inside the second range and is rejected explicitly). An empty ``s``
    yields True.
    """
    return all(
        ('\u3040' <= ch <= '\u309F' or '\u30A0' <= ch <= '\u30FF') and ch != "ー"
        for ch in s
    )
def add_dakuten_handakuten(query, string_type):
    """Enumerate voiced / semi-voiced spelling variants of ``query``.

    ``query`` is first normalised to the script named by ``string_type``.
    Each character then contributes itself plus, where the tables below
    define one, its dakuten (voiced) form and its handakuten (semi-voiced)
    form. The result is the Cartesian product of those per-character
    choices.

    Args:
        query: String to expand.
        string_type: Target script, either ``"hiragana"`` or ``"katakana"``.

    Returns:
        A list of tuples, one per combination, each holding one chosen
        character per input character. An empty ``query`` yields ``[()]``.

    Raises:
        ValueError: If ``string_type`` is not one of the supported values.
    """
    # Fail early with a clear message; otherwise the lookup tables below
    # would simply be undefined for an unknown script name.
    if string_type not in ("hiragana", "katakana"):
        raise ValueError(
            f"string_type must be 'hiragana' or 'katakana', got {string_type!r}"
        )

    def shift_script(text, first, last, offset):
        """Move every character in [first, last] by ``offset`` code points."""
        return ''.join(
            chr(ord(char) + offset) if first <= char <= last else char
            for char in text
        )

    if string_type == "hiragana":
        # Katakana -> hiragana: the two blocks are 96 code points apart.
        s = shift_script(query, 'ァ', 'ヶ', -96)
        dakuon_map = {
            'か': 'が', 'き': 'ぎ', 'く': 'ぐ', 'け': 'げ', 'こ': 'ご',
            'さ': 'ざ', 'し': 'じ', 'す': 'ず', 'せ': 'ぜ', 'そ': 'ぞ',
            'た': 'だ', 'ち': 'ぢ', 'つ': 'づ', 'て': 'で', 'と': 'ど',
            'は': 'ば', 'ひ': 'び', 'ふ': 'ぶ', 'へ': 'べ', 'ほ': 'ぼ',
        }
        handakuon_map = {
            'は': 'ぱ', 'ひ': 'ぴ', 'ふ': 'ぷ', 'へ': 'ぺ', 'ほ': 'ぽ',
        }
    else:
        # Hiragana -> katakana.
        s = shift_script(query, 'ぁ', 'ゖ', 96)
        dakuon_map = {
            'カ': 'ガ', 'キ': 'ギ', 'ク': 'グ', 'ケ': 'ゲ', 'コ': 'ゴ',
            'サ': 'ザ', 'シ': 'ジ', 'ス': 'ズ', 'セ': 'ゼ', 'ソ': 'ゾ',
            'タ': 'ダ', 'チ': 'ヂ', 'ツ': 'ヅ', 'テ': 'デ', 'ト': 'ド',
            'ハ': 'バ', 'ヒ': 'ビ', 'フ': 'ブ', 'ヘ': 'ベ', 'ホ': 'ボ',
            'ウ': 'ヴ',
        }
        handakuon_map = {
            'ハ': 'パ', 'ヒ': 'ピ', 'フ': 'プ', 'ヘ': 'ペ', 'ホ': 'ポ',
        }

    # For each character: the character itself, then its voiced and
    # semi-voiced forms when they exist (in that order).
    options = [
        [char] + [table[char] for table in (dakuon_map, handakuon_map) if char in table]
        for char in s
    ]

    # Every combination of per-character choices.
    return list(itertools.product(*options))
def add_dashes(s):
    """Return every way of optionally following each item of ``s`` with "ー".

    For a sequence of n items this produces 2**n strings. The choice for
    the first item varies fastest in the output order. A falsy ``s``
    yields ``['']``.
    """
    if not s:
        return ['']

    patterns = ['']
    # Build right-to-left: each step prepends one item (with and without a
    # trailing "ー") to every tail assembled so far.
    for head in reversed(s):
        patterns = [
            head + mark + tail
            for tail in patterns
            for mark in ('', 'ー')
        ]
    return patterns
def compute_losses(candidates, model, tokenizer):
    """Score each candidate by its mean next-token cross-entropy.

    The candidates are tokenized as one padded batch and run through
    ``model``. The logits at position t are compared with the token at
    position t + 1, and the per-token losses are averaged over each
    candidate's non-padding target tokens only.

    Args:
        candidates: List of strings to score.
        model: Callable taking the tokenizer output as keyword arguments
            and returning an object with a ``logits`` attribute; must
            expose ``device``.
        tokenizer: Callable returning a mapping with ``"input_ids"`` that
            supports item assignment and ``.to(device)``; must expose
            ``pad_token_id``.

    Returns:
        A 1-D tensor holding one mean loss per candidate. A candidate with
        no target token left after the shift divides by zero and is NaN.
    """
    inputs = tokenizer(candidates, return_tensors="pt", padding=True)

    loss_fct = CrossEntropyLoss(reduction="none")
    ignore_index = loss_fct.ignore_index

    # Padding positions must not contribute to the loss.
    inputs["labels"] = inputs["input_ids"].masked_fill(
        inputs["input_ids"] == tokenizer.pad_token_id, ignore_index
    )
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

        # Position t predicts token t + 1.
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = inputs["labels"][..., 1:].contiguous()

        # Flatten with the logits' own last dimension so a model whose
        # output width differs from its configured vocab size still works.
        losses_flat = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        losses_seq = losses_flat.view(shift_labels.shape)

        # Pad labels were replaced by ignore_index above, so the mask has
        # to test for ignore_index. Testing for pad_token_id would match
        # nothing and divide every row by the padded length.
        mask_labels = shift_labels != ignore_index
        losses = torch.sum(losses_seq * mask_labels, -1) / mask_labels.sum(-1)
    return losses
def search_candidates(query, query_candidates, model, tokenizer, top_k=100):
    """Find the lowest-loss spellings of ``query``, one character at a time.

    The best spellings of ``query[:-1]`` are taken from ``query_candidates``
    (or computed recursively when absent). Each is extended with every
    variant of the last character: hiragana or katakana, with optional
    dakuten / handakuten, with optional trailing "ー". The extended strings
    are scored with ``compute_losses`` and the ``top_k`` lowest-loss ones
    are kept.

    Args:
        query: The input string.
        query_candidates: Dict used as a cache, mapping a query prefix to
            ``(candidates, losses)``. It is updated in place with the
            result for ``query``.
        model: Passed through to ``compute_losses``.
        tokenizer: Passed through to ``compute_losses``.
        top_k: Maximum number of candidates to keep.

    Returns:
        ``(topk_candidates, topk_losses)``: two parallel lists sorted by
        ascending loss.
    """
    if not query:
        # Base case. Without it an unseeded cache recurses forever, because
        # the prefix of "" is again "". A caller-provided seed for the empty
        # prefix still takes precedence.
        return query_candidates.get(query, ([""], [0.0]))

    old_query = query[:-1]
    if old_query in query_candidates:
        old_candidates, _ = query_candidates[old_query]
    else:
        old_candidates, _ = search_candidates(
            old_query, query_candidates, model=model, tokenizer=tokenizer, top_k=top_k
        )

    string = query[-1]
    candidates = []
    for string_type in ("hiragana", "katakana"):
        for variant in add_dakuten_handakuten(string, string_type=string_type):
            candidates.extend(add_dashes(variant))
    # A character with no counterpart in the other script yields the same
    # variants twice; drop repeats (keeping order) so they neither get
    # scored twice nor take up two top_k slots.
    candidates = list(dict.fromkeys(candidates))

    new_candidates = [
        prefix + suffix
        for prefix, suffix in itertools.product(old_candidates, candidates)
    ]

    losses = compute_losses(new_candidates, model=model, tokenizer=tokenizer)
    sorted_items = torch.sort(losses)

    # Index the Python list directly; no need to go through a NumPy array.
    top_indices = sorted_items.indices[:top_k].cpu().tolist()
    topk_candidates = [new_candidates[i] for i in top_indices]
    topk_losses = sorted_items.values[:top_k].cpu().tolist()

    query_candidates[query] = (topk_candidates, topk_losses)
    return topk_candidates, topk_losses