Spaces:
Running
Running
from __future__ import unicode_literals | |
import re | |
import unicodedata | |
import torch | |
import streamlit as st | |
from transformers import T5ForConditionalGeneration, T5Tokenizer | |
def load_model():
    """Download the Qiita title-generation T5 model and its tokenizer.

    Both are fetched from the Hugging Face model hub repository
    "sonoisa/t5-qiita-title-generation". When CUDA is available the model is
    moved to the GPU before it is returned.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``T5ForConditionalGeneration``
        model and its SentencePiece-based ``T5Tokenizer``.
    """
    repo_id = "sonoisa/t5-qiita-title-generation"

    # SentencePiece tokenizer.
    # NOTE(review): `is_fast` does not look like a documented from_pretrained
    # option (AutoTokenizer uses `use_fast`); confirm it has any effect.
    sp_tokenizer = T5Tokenizer.from_pretrained(repo_id, is_fast=True)
    # Fine-tuned weights.
    model = T5ForConditionalGeneration.from_pretrained(repo_id)

    # Put the weights on the GPU only when one is present.
    if torch.cuda.is_available():
        model.cuda()

    return model, sp_tokenizer
def unicode_normalize(cls, s):
    """Apply NFKC normalization only to runs of characters in ``cls``.

    ``s`` is split on runs of characters matching the regex character-class
    body ``cls``. Only those runs are NFKC-normalized; all other text is kept
    as-is. Afterwards every fullwidth hyphen-minus (U+FF0D) in the string is
    turned into an ASCII hyphen-minus.

    Args:
        cls: Contents of a regex character class, without the surrounding
            brackets (for example a range of fullwidth digits).
        s: Text to normalize.

    Returns:
        The partially normalized string.
    """
    pt = re.compile("([{}]+)".format(cls))

    def norm(c):
        # Pieces produced by re.split are either whole matches of pt or
        # stretches containing no class character, so match() is sufficient.
        return unicodedata.normalize("NFKC", c) if pt.match(c) else c

    s = "".join(norm(x) for x in re.split(pt, s))
    # Written as an escape on purpose: the literal fullwidth character had
    # degraded into an ASCII "-", turning this step into a no-op.
    s = s.replace("\uFF0D", "-")
    return s
def remove_extra_spaces(s):
    """Collapse space runs and drop spaces adjacent to Japanese text.

    Runs of ASCII spaces and ideographic spaces (U+3000) become a single ASCII
    space. Then a single space is removed wherever it sits between two
    characters of the CJK/kana/fullwidth blocks listed below, or between one
    of those characters and a Basic Latin character (in either order).

    Args:
        s: Text to clean.

    Returns:
        The cleaned string.
    """
    # The ideographic space is written as an escape: as a literal it had
    # degraded into a second ASCII space, so fullwidth spaces were never
    # collapsed.
    s = re.sub("[ \u3000]+", " ", s)
    blocks = "".join(
        (
            "\u4E00-\u9FFF",  # CJK UNIFIED IDEOGRAPHS
            "\u3040-\u309F",  # HIRAGANA
            "\u30A0-\u30FF",  # KATAKANA
            "\u3000-\u303F",  # CJK SYMBOLS AND PUNCTUATION
            "\uFF00-\uFFEF",  # HALFWIDTH AND FULLWIDTH FORMS
        )
    )
    basic_latin = "\u0000-\u007F"

    def remove_space_between(cls1, cls2, s):
        p = re.compile("([{}]) ([{}])".format(cls1, cls2))
        # Repeat until stable: in "a b c" the matches "a b" and "b c" overlap,
        # so one sub() pass only removes every other space.
        while p.search(s):
            s = p.sub(r"\1\2", s)
        return s

    s = remove_space_between(blocks, blocks, s)
    s = remove_space_between(blocks, basic_latin, s)
    s = remove_space_between(basic_latin, blocks, s)
    return s
def normalize_neologd(s):
    """Normalize Japanese text following the mecab-ipadic-NEologd recipe.

    Steps, in order:

    1. Strip surrounding whitespace.
    2. NFKC-normalize fullwidth digits and Latin letters and halfwidth
       katakana/punctuation (U+FF61-U+FF9F).
    3. Collapse runs of hyphen-like characters to "-", runs of long-vowel
       marks (choonpu) to U+30FC and runs of tilde-like characters to U+301C
       (the tilde rule is Isao Sonobe's modification).
    4. Map ASCII punctuation (plus the yen sign U+00A5) and halfwidth CJK
       punctuation to fullwidth forms, then clean spaces with
       remove_extra_spaces.
    5. NFKC-normalize that fullwidth punctuation back to ASCII, except the
       fullwidth equals sign U+FF1D (the middle dot and corner brackets are
       also kept), and turn the curly quotes U+2019 / U+201D into ASCII
       ' and ".

    Args:
        s: Text to normalize.

    Returns:
        The normalized string.
    """
    # Every non-ASCII character below is written as a \u escape. As literal
    # characters the fullwidth/halfwidth forms had been silently replaced by
    # ASCII or fullwidth-katakana look-alikes, which made several steps
    # no-ops and turned the choonpu class into an invalid range (re.error).
    s = s.strip()
    # Fullwidth 0-9, A-Z, a-z and the halfwidth katakana block.
    s = unicode_normalize("\uFF10-\uFF19\uFF21-\uFF3A\uFF41-\uFF5A\uFF61-\uFF9F", s)

    # normalize hyphens
    s = re.sub(
        "[\u02D7\u058A\u2010\u2011\u2012\u2013\u2043\u207B\u208B\u2212]+", "-", s
    )
    # normalize choonpus
    s = re.sub("[\uFE63\uFF0D\uFF70\u2014\u2015\u2500\u2501\u30FC]+", "\u30FC", s)
    # normalize tildes (modified by Isao Sonobe)
    s = re.sub("[~\u223C\u223E\u301C\u3030\uFF5E]+", "\u301C", s)

    # str.maketrans raises ValueError if the two tables ever get out of sync,
    # instead of silently truncating like zip() would.
    s = s.translate(
        str.maketrans(
            # ASCII punctuation (with the yen sign), then halfwidth 。、・「」
            "!\"#$%&'()*+,-./:;<=>?@[\u00A5]^_`{|}~"
            "\uFF61\uFF64\uFF65\uFF62\uFF63",
            # Fullwidth counterparts; " and ' become curly quotes, ~ becomes U+301C
            "\uFF01\u201D\uFF03\uFF04\uFF05\uFF06\u2019\uFF08\uFF09\uFF0A\uFF0B"
            "\uFF0C\uFF0D\uFF0E\uFF0F\uFF1A\uFF1B\uFF1C\uFF1D\uFF1E\uFF1F\uFF20"
            "\uFF3B\uFFE5\uFF3D\uFF3E\uFF3F\uFF40\uFF5B\uFF5C\uFF5D\u301C"
            "\u3002\u3001\u30FB\u300C\u300D",
        )
    )
    s = remove_extra_spaces(s)
    # Back to ASCII; U+FF1D (fullwidth "=") is deliberately absent so it stays.
    s = unicode_normalize(
        "\uFF01\u201D\uFF03\uFF04\uFF05\uFF06\u2019\uFF08\uFF09\uFF0A\uFF0B"
        "\uFF0C\uFF0D\uFF0E\uFF0F\uFF1A\uFF1B\uFF1C\uFF1E\uFF1F\uFF20"
        "\uFF3B\uFFE5\uFF3D\uFF3E\uFF3F\uFF40\uFF5B\uFF5C\uFF5D\u301C",
        s,
    )
    s = s.replace("\u2019", "'")
    s = s.replace("\u201D", '"')
    return s
# Fenced code blocks (``` ... ```), possibly spanning several lines.
CODE_PATTERN = re.compile(r"```.*?```", re.MULTILINE | re.DOTALL)
# Markdown links and images; group 1 is the link text / alt text.
LINK_PATTERN = re.compile(r"!?\[([^\]\)]+)\]\([^\)]+\)")
# Inline HTML <img ...> tags.
IMG_PATTERN = re.compile(r"<img[^>]*>")
# Bare http(s)/ftp(s) URLs up to the next whitespace.
URL_PATTERN = re.compile(r"(http|ftp)s?://[^\s]+")
# Any run of line breaks together with surrounding whitespace.
NEWLINES_PATTERN = re.compile(r"(\s*\n\s*)+")


def clean_markdown(markdown_text):
    """Strip Markdown constructs that carry little text from an article body.

    Fenced code blocks, <img> tags and bare URLs are deleted; links and
    images are replaced by their link/alt text; whitespace runs containing
    line breaks collapse to one newline; remaining backticks are dropped.

    Args:
        markdown_text: Article body in Markdown.

    Returns:
        The cleaned text.
    """
    # Order matters: links are reduced to their text before bare URLs are
    # removed, so link targets vanish together with the link syntax.
    substitutions = (
        (CODE_PATTERN, r""),
        (LINK_PATTERN, r"\1"),
        (IMG_PATTERN, r""),
        (URL_PATTERN, r""),
        (NEWLINES_PATTERN, r"\n"),
    )
    cleaned = markdown_text
    for pattern, replacement in substitutions:
        cleaned = pattern.sub(replacement, cleaned)
    # Leftover inline-code backticks.
    return cleaned.replace("`", "")
def normalize_text(markdown_text):
    """Turn a Markdown article body into one normalized, lowercase line.

    The body is cleaned with clean_markdown, tabs become spaces, the text is
    normalized with normalize_neologd and lowercased, and finally every
    newline is replaced by a space.

    Args:
        markdown_text: Article body in Markdown.

    Returns:
        The normalized single-line text.
    """
    text = clean_markdown(markdown_text).replace("\t", " ")
    text = normalize_neologd(text).lower()
    return text.replace("\n", " ")
def preprocess_qiita_body(markdown_text):
    """Build the model input string for an article body.

    The body is normalized with normalize_text, cut to its first 4000
    characters and prefixed with "body: " (presumably the prefix the model
    was fine-tuned with — confirm against the training setup).

    Args:
        markdown_text: Article body in Markdown.

    Returns:
        The prefixed, normalized, truncated input text.
    """
    normalized = normalize_text(markdown_text)
    return f"body: {normalized[:4000]}"
def postprocess_title(title):
    """Remove a single leading "title: " prefix from a generated title.

    Args:
        title: Decoded model output.

    Returns:
        ``title`` without the prefix; unchanged if it does not start with it.
    """
    # "^" without re.MULTILINE anchors only at the very start of the string,
    # so at most one occurrence is removed.
    leading_prefix = re.compile(r"^title: ")
    return leading_prefix.sub("", title)
def _escape_markdown(text):
    """Backslash-escape every ASCII punctuation character in ``text``.

    CommonMark allows a backslash before any ASCII punctuation character, so
    the escaped text renders literally: underscores, asterisks, backticks,
    brackets or "$" in a generated title are no longer turned into emphasis,
    code, links or math when passed to st.markdown.
    """
    return re.sub(r"([!-/:-@\[-`{-~])", r"\\\1", text)


st.title("Qiita記事タイトル案生成")
description_text = st.empty()

# Load the model only when it is not yet stored in st.session_state, so it is
# reused across reruns of this script within the same session.
if "trained_model" not in st.session_state:
    description_text.text("...モデル読み込み中...")
    trained_model, tokenizer = load_model()
    trained_model.eval()  # inference mode
    st.session_state.trained_model = trained_model
    st.session_state.tokenizer = tokenizer

trained_model = st.session_state.trained_model
tokenizer = st.session_state.tokenizer

# Whether the GPU is used.
USE_GPU = torch.cuda.is_available()

description_text.text("記事の本文をコピペ入力して、タイトル生成ボタンを押すと、タイトル案が10個生成されます。\nGPUが使えないため生成に数十秒かかります。")
qiita_body = st.text_area(label="記事の本文", value="", height=300, max_chars=4000)
answer = st.button("タイトル生成")

if answer and not qiita_body.strip():
    # Nothing to summarize: skip the slow 10-beam generation on an empty body.
    st.warning("記事の本文を入力してください。")
elif answer:
    title_fields = st.empty()
    title_fields.markdown("...生成中...")

    MAX_SOURCE_LENGTH = 512  # max number of tokens of the input article body
    MAX_TARGET_LENGTH = 64  # max number of tokens of each generated title

    # Preprocess and tokenize the body (a batch of one).
    inputs = [preprocess_qiita_body(qiita_body)]
    batch = tokenizer.batch_encode_plus(
        inputs,
        max_length=MAX_SOURCE_LENGTH,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )
    input_ids = batch["input_ids"]
    input_mask = batch["attention_mask"]
    if USE_GPU:
        input_ids = input_ids.cuda()
        input_mask = input_mask.cuda()

    # Generate: 10 beams split into 10 groups, with diversity_penalty pushing
    # the groups apart, returning 10 candidate titles. Scores are not used,
    # so output_scores is not requested.
    # NOTE(review): do_sample is not passed; unless the model's generation
    # config enables sampling, temperature has no effect here — confirm.
    outputs = trained_model.generate(
        input_ids=input_ids,
        attention_mask=input_mask,
        max_length=MAX_TARGET_LENGTH,
        return_dict_in_generate=True,
        temperature=1.0,  # temperature parameter meant to add randomness
        num_beams=10,  # beam-search width
        diversity_penalty=1.0,  # penalty that makes the beam groups diverge
        num_beam_groups=10,  # number of beam groups
        num_return_sequences=10,  # number of titles to return
        repetition_penalty=1.5,  # penalty against repeating the same tokens
    )

    # Convert the generated token ids back to strings.
    generated_titles = [
        tokenizer.decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        for ids in outputs.sequences
    ]

    # Show the titles as a numbered Markdown list. Each title is escaped so
    # characters such as "_" or "$" are displayed instead of interpreted.
    lines = ["## タイトル案:", ""]
    lines.extend(
        f"1. {_escape_markdown(postprocess_title(title))}"
        for title in generated_titles
    )
    title_fields.markdown("\n".join(lines) + "\n")