|
import torch |
|
from transformers import AutoModelForCausalLM, T5Tokenizer |
|
import csv, re, mojimoji |
|
|
|
class Zmaker: |
|
|
|
|
|
gpt_model_name = "rinna/japanese-gpt2-medium" |
|
|
|
|
|
min_len, max_len = 1, 128 |
|
|
|
|
|
top_k, top_p = 40, 0.95 |
|
num_text = 1 |
|
temp = 0.1 |
|
repeat_ngram_size = 1 |
|
|
|
|
|
use_cpu = True |
|
|
|
def __init__(self, ft_path = None): |
|
"""コンストラクタ |
|
|
|
コンストラクタ。モデルをファイルから読み込む場合と, |
|
新規作成する場合で動作を分ける. |
|
|
|
Args: |
|
ft_path : ファインチューニングされたモデルのパス. |
|
Returns: |
|
なし |
|
""" |
|
|
|
|
|
self.__SetModel(ft_path) |
|
|
|
|
|
if self.use_cpu: |
|
device = torch.device('cpu') |
|
else: |
|
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') |
|
self.model.to(device) |
|
|
|
|
|
|
|
def __SetModel(self, ft_path = None): |
|
"""GPT2の設定 |
|
|
|
GPT2のTokenizerおよびモデルを設定する. |
|
ユーザー定義後と顔文字も語彙として認識されるように設定する. |
|
|
|
Args: |
|
ft_path : ファインチューニング済みのモデルを読み込む |
|
何も指定しないとself.gpt_model_nameの事前学習モデルを |
|
ネットからダウンロードする. |
|
Returns: |
|
なし |
|
""" |
|
|
|
self.tokenizer = T5Tokenizer.from_pretrained(self.gpt_model_name) |
|
self.tokenizer.do_lower_case = True |
|
|
|
|
|
if ft_path is not None: |
|
self.model = AutoModelForCausalLM.from_pretrained( |
|
ft_path, |
|
) |
|
else: |
|
print("fine-tuned model was not found") |
|
|
|
|
|
self.model.eval() |
|
|
|
def __TextCleaning(self, texts): |
|
"""テキストの前処理をする |
|
|
|
テキストの前処理を行う.具体的に行うこととしては... |
|
・全角/半角スペースの除去 |
|
・半角数字/アルファベットの全角化 |
|
""" |
|
|
|
texts = [re.sub("[\u3000 \t \s \n]", "", t) for t in texts] |
|
|
|
|
|
texts = [mojimoji.han_to_zen(t) for t in texts] |
|
return texts |
|
|
|
|
|
def GenLetter(self, prompt): |
|
"""怪文書の生成 |
|
|
|
GPT2で怪文書を生成する. |
|
promptに続く文章を生成して出力する |
|
|
|
Args: |
|
prompt : 文章の先頭 |
|
Retunrs: |
|
生成された文章のリスト |
|
""" |
|
|
|
|
|
prompt_clean = [prompt] |
|
|
|
|
|
x = self.tokenizer.encode( |
|
prompt_clean[0], return_tensors="pt", |
|
add_special_tokens=False |
|
) |
|
|
|
|
|
if self.use_cpu: |
|
device = torch.device('cpu') |
|
else: |
|
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') |
|
x = x.to(device) |
|
|
|
|
|
with torch.no_grad(): |
|
y = self.model.generate( |
|
x, |
|
min_length=self.min_len, |
|
max_length=self.max_len, |
|
do_sample=True, |
|
top_k=self.top_k, |
|
top_p=self.top_p, |
|
temperature=self.temp, |
|
no_repeat_ngram_size = self.repeat_ngram_size, |
|
num_return_sequences=self.num_text, |
|
pad_token_id=self.tokenizer.pad_token_id, |
|
bos_token_id=self.tokenizer.bos_token_id, |
|
eos_token_id=self.tokenizer.eos_token_id, |
|
early_stopping=True |
|
) |
|
|
|
|
|
res = self.tokenizer.batch_decode(y, skip_special_tokens=True) |
|
return res |