File size: 5,018 Bytes
ea7e714
 
 
 
 
 
 
 
 
 
 
 
 
61938d3
ea7e714
61938d3
ea7e714
 
 
 
 
 
 
 
 
 
 
 
93770f8
ea7e714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61938d3
ea7e714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
from transformers import AutoModelForCausalLM, T5Tokenizer
import csv, re, mojimoji

class Zmaker:

    #GPT2のモデル名
    gpt_model_name = "rinna/japanese-gpt2-medium"

    #文章の最大長
    min_len, max_len = 1, 128

    #予測時のパラメータ
    top_k, top_p = 40, 0.95 #top-k検索の閾値
    num_text = 1 #出力する文の数
    temp = 0.1
    repeat_ngram_size = 1

    #推論にCPU利用を強制するか
    use_cpu = True

    def __init__(self, ft_path = None):
        """コンストラクタ

          コンストラクタ。モデルをファイルから読み込む場合と,
          新規作成する場合で動作を分ける.

          Args:
              ft_path : ファインチューニングされたモデルのパス.
          Returns:
              なし
        """

        #モデルの設定
        self.__SetModel(ft_path)

        #モデルの状態をCPUかGPUかで切り替える
        if self.use_cpu: #CPUの利用を強制する場合の処理
            device = torch.device('cpu')
        else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
            device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model.to(device)



    def __SetModel(self, ft_path = None):
      """GPT2の設定

        GPT2のTokenizerおよびモデルを設定する.
        ユーザー定義後と顔文字も語彙として認識されるように設定する.
        
        Args:
            ft_path : ファインチューニング済みのモデルを読み込む
                      何も指定しないとself.gpt_model_nameの事前学習モデルを
                      ネットからダウンロードする.
        Returns:
            なし
      """
      #GPT2のTokenizerのインスタンスを生成
      self.tokenizer = T5Tokenizer.from_pretrained(self.gpt_model_name)
      self.tokenizer.do_lower_case = True # due to some bug of tokenizer config loading

      #モデルの読み込み
      if ft_path is not None:
          self.model = AutoModelForCausalLM.from_pretrained(
              ft_path, #torch_dtype = torch.bfloat16
          )
      else:
          print("fine-tuned model was not found")
    
      #モデルをevalモードに
      self.model.eval()

    def __TextCleaning(self, texts):
        """テキストの前処理をする

          テキストの前処理を行う.具体的に行うこととしては...
          ・全角/半角スペースの除去
          ・半角数字/アルファベットの全角化
        """
        #半角スペース,タブ,改行改ページを削除
        texts = [re.sub("[\u3000 \t \s \n]", "", t) for t in texts]

        #半角/全角を変換
        texts = [mojimoji.han_to_zen(t) for t in texts]
        return texts
    

    def GenLetter(self, prompt):
      """怪文書の生成

        GPT2で怪文書を生成する.
        promptに続く文章を生成して出力する

        Args:
            prompt : 文章の先頭
        Retunrs:
            生成された文章のリスト
      """

      #テキストをクリーニング
      prompt_clean = [prompt]
      
      #文章をtokenizerでエンコード
      x = self.tokenizer.encode(
          prompt_clean[0], return_tensors="pt", 
          add_special_tokens=False
      )
      
      #デバイスの選択
      if self.use_cpu: #CPUの利用を強制する場合の処理
          device = torch.device('cpu')
      else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
          device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
      x = x.to(device)

      #gpt2による推論
      with torch.no_grad():
          y = self.model.generate(
              x, #入力
              min_length=self.min_len,  # 文章の最小長
              max_length=self.max_len,  # 文章の最大長
              do_sample=True,   # 次の単語を確率で選ぶ
              top_k=self.top_k, # Top-Kサンプリング
              top_p=self.top_p,  # Top-pサンプリング
              temperature=self.temp,  # 確率分布の調整
              no_repeat_ngram_size = self.repeat_ngram_size, #同じ単語を何回繰り返していいか
              num_return_sequences=self.num_text,  # 生成する文章の数
              pad_token_id=self.tokenizer.pad_token_id,  # パディングのトークンID
              bos_token_id=self.tokenizer.bos_token_id,  # テキスト先頭のトークンID
              eos_token_id=self.tokenizer.eos_token_id,  # テキスト終端のトークンID
              early_stopping=True
          )
      
      # 特殊トークンをスキップして推論結果を文章にデコード
      res = self.tokenizer.batch_decode(y, skip_special_tokens=True)
      return res