from __future__ import unicode_literals
import re
import unicodedata
import torch
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer

def load_model():
    # Download the fine-tuned model from the Hugging Face model hub
    model_dir_name = "sonoisa/t5-qiita-title-generation"

    # Tokenizer (SentencePiece)
    tokenizer = T5Tokenizer.from_pretrained(model_dir_name, is_fast=True)

    # Fine-tuned model
    trained_model = T5ForConditionalGeneration.from_pretrained(model_dir_name)

    # Move the model to the GPU if one is available
    USE_GPU = torch.cuda.is_available()
    if USE_GPU:
        trained_model.cuda()
    return trained_model, tokenizer
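
# Note: the app below caches the loaded model in st.session_state. On newer
# Streamlit versions the same effect is commonly achieved with a cached loader;
# a hedged sketch, assuming Streamlit >= 1.18 where st.cache_resource exists:
#
#   @st.cache_resource
#   def load_model_cached():
#       return load_model()
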
def unicode_normalize(cls, s):
    pt = re.compile("([{}]+)".format(cls))

    def norm(c):
        return unicodedata.normalize("NFKC", c) if pt.match(c) else c

    s = "".join(norm(x) for x in re.split(pt, s))
    s = re.sub("－", "-", s)  # fullwidth hyphen-minus (U+FF0D) to ASCII hyphen
    return s
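
# A hedged example of the selective NFKC normalization above (illustrative strings,
# not from the original app): only characters in the given class are normalized,
# everything else is left untouched.
#
#   unicode_normalize("０-９Ａ-Ｚａ-ｚ｡-ﾟ", "Ｐｙｔｈｏｎ３とｶﾀｶﾅ")
#   -> "Python3とカタカナ"
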
def remove_extra_spaces(s):
    # Collapse runs of ASCII and fullwidth (U+3000) spaces into a single space
    s = re.sub("[ 　]+", " ", s)
    blocks = "".join(
        (
            "\u4E00-\u9FFF",  # CJK UNIFIED IDEOGRAPHS
            "\u3040-\u309F",  # HIRAGANA
            "\u30A0-\u30FF",  # KATAKANA
            "\u3000-\u303F",  # CJK SYMBOLS AND PUNCTUATION
            "\uFF00-\uFFEF",  # HALFWIDTH AND FULLWIDTH FORMS
        )
    )
    basic_latin = "\u0000-\u007F"

    def remove_space_between(cls1, cls2, s):
        p = re.compile("([{}]) ([{}])".format(cls1, cls2))
        while p.search(s):
            s = p.sub(r"\1\2", s)
        return s

    s = remove_space_between(blocks, blocks, s)
    s = remove_space_between(blocks, basic_latin, s)
    s = remove_space_between(basic_latin, blocks, s)
    return s
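
# Hedged examples of the spacing rules above (illustrative strings, not from the
# original app): spaces between Japanese characters, and between Japanese and Latin
# characters, are removed, while spaces between Latin words are kept.
#
#   remove_extra_spaces("検索 エンジン 自作")   -> "検索エンジン自作"
#   remove_extra_spaces("アルゴリズム C")      -> "アルゴリズムC"
#   remove_extra_spaces("Coding  the  Matrix") -> "Coding the Matrix"
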
def normalize_neologd(s):
    s = s.strip()
    s = unicode_normalize("０-９Ａ-Ｚａ-ｚ｡-ﾟ", s)

    def maketrans(f, t):
        return {ord(x): ord(y) for x, y in zip(f, t)}

    s = re.sub("[˗֊‐‑‒–⁃⁻₋−]+", "-", s)  # normalize hyphens
    s = re.sub("[﹣－ｰ—―─━ー]+", "ー", s)  # normalize choonpus
    s = re.sub("[~∼∾〜〰～]+", "〜", s)  # normalize tildes (modified by Isao Sonobe)
    s = s.translate(
        maketrans(
            "!\"#$%&'()*+,-./:;<=>?@[¥]^_`{|}~｡､･｢｣",
            "！”＃＄％＆’（）＊＋，－．／：；＜＝＞？＠［￥］＾＿｀｛｜｝〜。、・「」",
        )
    )
    s = remove_extra_spaces(s)
    s = unicode_normalize("！”＃＄％＆’（）＊＋，－．／：；＜＞？＠［￥］＾＿｀｛｜｝〜", s)  # keep ＝,・,「,」
    s = re.sub("[’]", "'", s)
    s = re.sub("[”]", '"', s)
    return s
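
# Hedged end-to-end examples, adapted from the mecab-ipadic-neologd normalization
# spec that this function follows (not re-verified against this exact app):
#
#   normalize_neologd("　　　ＰＲＭＬ　　副　読　本　　　") -> "PRML副読本"
#   normalize_neologd("検索 エンジン 自作 入門 を 買い ました!!!")
#   -> "検索エンジン自作入門を買いました!!!"
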
CODE_PATTERN = re.compile(r"```.*?```", re.MULTILINE | re.DOTALL)
LINK_PATTERN = re.compile(r"!?\[([^\]\)]+)\]\([^\)]+\)")
IMG_PATTERN = re.compile(r"<img[^>]*>")
URL_PATTERN = re.compile(r"(http|ftp)s?://[^\s]+")
NEWLINES_PATTERN = re.compile(r"(\s*\n\s*)+")
def clean_markdown(markdown_text):
    markdown_text = CODE_PATTERN.sub(r"", markdown_text)
    markdown_text = LINK_PATTERN.sub(r"\1", markdown_text)
    markdown_text = IMG_PATTERN.sub(r"", markdown_text)
    markdown_text = URL_PATTERN.sub(r"", markdown_text)
    markdown_text = NEWLINES_PATTERN.sub(r"\n", markdown_text)
    markdown_text = markdown_text.replace("`", "")
    return markdown_text
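
# A hedged example (illustrative markdown, not from the original app): fenced code
# blocks, images, and bare URLs are dropped, links are reduced to their anchor text,
# and runs of blank lines collapse into a single newline.
#
#   clean_markdown("# Title\n\n```python\nprint(1)\n```\nsee [docs](https://example.com)\n")
#   -> "# Title\nsee docs\n"
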
def normalize_text(markdown_text):
    markdown_text = clean_markdown(markdown_text)
    markdown_text = markdown_text.replace("\t", " ")
    markdown_text = normalize_neologd(markdown_text).lower()
    markdown_text = markdown_text.replace("\n", " ")
    return markdown_text


def preprocess_qiita_body(markdown_text):
    return "body: " + normalize_text(markdown_text)[:4000]


def postprocess_title(title):
    return re.sub(r"^title: ", "", title)
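
# Hedged examples of the model's I/O convention, as inferred from the two helpers
# above: inputs are prefixed with "body: " (and truncated to 4000 characters), and
# generated titles carry a "title: " prefix that is stripped before display.
#
#   preprocess_qiita_body("これは本文です。") -> "body: これは本文です。"
#   postprocess_title("title: タイトル案")    -> "タイトル案"
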
st.title("Qiita記事タイトル案生成")
description_text = st.empty()

if "trained_model" not in st.session_state:
    description_text.text("...モデル読み込み中...")
    trained_model, tokenizer = load_model()
    trained_model.eval()
    st.session_state.trained_model = trained_model
    st.session_state.tokenizer = tokenizer

trained_model = st.session_state.trained_model
tokenizer = st.session_state.tokenizer

# Whether a GPU is available
USE_GPU = torch.cuda.is_available()

description_text.text("記事の本文をコピペ入力して、タイトル生成ボタンを押すと、タイトル案が10個生成されます。\nGPUが使えないため生成に数十秒かかります。")

qiita_body = st.text_area(label="記事の本文", value="", height=300, max_chars=4000)

answer = st.button("タイトル生成")

if answer:
    title_fields = st.empty()
    title_fields.markdown("...生成中...")

    MAX_SOURCE_LENGTH = 512  # maximum number of input tokens from the article body
    MAX_TARGET_LENGTH = 64  # maximum number of tokens in a generated title

    # Preprocess and tokenize the article body
    inputs = [preprocess_qiita_body(qiita_body)]
    batch = tokenizer.batch_encode_plus(
        inputs,
        max_length=MAX_SOURCE_LENGTH,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )

    input_ids = batch["input_ids"]
    input_mask = batch["attention_mask"]
    if USE_GPU:
        input_ids = input_ids.cuda()
        input_mask = input_mask.cuda()

    # Run the generation
    outputs = trained_model.generate(
        input_ids=input_ids,
        attention_mask=input_mask,
        max_length=MAX_TARGET_LENGTH,
        return_dict_in_generate=True,
        output_scores=True,
        temperature=1.0,  # temperature parameter that adds randomness to generation
        num_beams=10,  # beam width of the beam search
        diversity_penalty=1.0,  # penalty that encourages diversity among the results
        num_beam_groups=10,  # number of beam search groups
        num_return_sequences=10,  # number of sequences to generate
        repetition_penalty=1.5,  # penalty against repetition (mode collapse)
    )
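
    # A hedged note on the settings above: in Hugging Face transformers, diverse
    # beam search requires num_beams to be divisible by num_beam_groups (here,
    # 10 beams in 10 groups, i.e. one beam per group), and num_return_sequences
    # must not exceed num_beams.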

    # Decode the generated token sequences back into strings
    generated_titles = [
        tokenizer.decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        for ids in outputs.sequences
    ]

    # Display the generated title candidates
    titles = "## タイトル案:\n\n"
    for title in generated_titles:
        titles += f"1. {postprocess_title(title)}\n"
    title_fields.markdown(titles)