Spaces:
Running
Running
from __future__ import unicode_literals | |
import re | |
import unicodedata | |
import torch | |
import streamlit as st | |
import pandas as pd | |
import pyarrow as pa | |
import pyarrow.parquet as pq | |
import numpy as np | |
import scipy.spatial | |
from transformers import BertJapaneseTokenizer, BertModel | |
import pyminizip | |
def unicode_normalize(cls, s):
    """NFKC-normalize only the runs of ``s`` made of characters in ``cls``.

    Args:
        cls: Body of a regex character class (ranges such as ``0-9`` are
            allowed). Literal ``[`` / ``]`` are escaped here, so a class that
            lists bracket characters cannot terminate early.
        s: Text to normalize.

    Returns:
        ``s`` with every maximal run of ``cls`` characters NFKC-normalized and
        any remaining full-width hyphen-minus (U+FF0D) replaced by ASCII "-".
    """
    escaped = cls.replace("[", "\\[").replace("]", "\\]")
    pt = re.compile("([{}]+)".format(escaped))

    def norm(c):
        return unicodedata.normalize("NFKC", c) if pt.match(c) else c

    # The capturing group makes re.split keep the matched runs, so the pieces
    # reassemble into the complete string.
    s = "".join(norm(x) for x in re.split(pt, s))
    # Written as an escape: a literal full-width "－" is easily mangled into an
    # ASCII "-", which would turn this replacement into a no-op.
    s = s.replace("\uFF0D", "-")
    return s
def remove_extra_spaces(s):
    """Collapse spaces and remove those adjacent to CJK characters.

    Runs of ASCII spaces and ideographic spaces (U+3000) become a single ASCII
    space. Then any single space between two CJK characters, or between a CJK
    character and a Basic Latin character (in either order), is removed.

    Args:
        s: Text to clean.

    Returns:
        The cleaned text.
    """
    # U+3000 is written as an escape so it cannot be lost if the source file
    # is width-normalized (which turns it into an ASCII space).
    s = re.sub("[ \u3000]+", " ", s)
    blocks = "".join(
        (
            "\u4E00-\u9FFF",  # CJK UNIFIED IDEOGRAPHS
            "\u3040-\u309F",  # HIRAGANA
            "\u30A0-\u30FF",  # KATAKANA
            "\u3000-\u303F",  # CJK SYMBOLS AND PUNCTUATION
            "\uFF00-\uFFEF",  # HALFWIDTH AND FULLWIDTH FORMS
        )
    )
    basic_latin = "\u0000-\u007F"

    def remove_space_between(cls1, cls2, s):
        p = re.compile("([{}]) ([{}])".format(cls1, cls2))
        # Repeat until stable: each match consumes the character after the
        # space, so in "あ い う" the second space is only reachable on the
        # next pass.
        while True:
            s, count = p.subn(r"\1\2", s)
            if count == 0:
                return s

    s = remove_space_between(blocks, blocks, s)
    s = remove_space_between(blocks, basic_latin, s)
    s = remove_space_between(basic_latin, blocks, s)
    return s
def normalize_neologd(s):
    """Normalize Japanese text following the mecab-ipadic-NEologd rules.

    This is the reference NEologd normalizer with two local modifications:
    tilde variants are folded into "〜" instead of being removed, and the
    result is lower-cased.

    Every non-ASCII character in the patterns below is spelled as a ``\\uXXXX``
    escape. Width normalization of the source file turns literal full-width
    characters into ASCII, which corrupts these tables (for example,
    "[﹣－ｰ…]" becomes the invalid range "[﹣-ー…]").

    Args:
        s: Text to normalize.

    Returns:
        The normalized, lower-cased text.
    """
    s = s.strip()
    # Full-width digits, full-width Latin letters, and half-width katakana /
    # CJK punctuation (U+FF61-U+FF9F) are folded via NFKC.
    s = unicode_normalize("\uFF10-\uFF19\uFF21-\uFF3A\uFF41-\uFF5A\uFF61-\uFF9F", s)

    # normalize hyphens: ˗ ֊ ‐ ‑ ‒ – ⁃ ⁻ ₋ −  ->  "-"
    s = re.sub(
        "[\u02D7\u058A\u2010\u2011\u2012\u2013\u2043\u207B\u208B\u2212]+", "-", s
    )
    # normalize choonpus: ﹣ － ｰ — ― ─ ━ ー  ->  "ー"
    s = re.sub("[\uFE63\uFF0D\uFF70\u2014\u2015\u2500\u2501\u30FC]+", "\u30FC", s)
    # normalize tildes: ~ ∼ ∾ 〜 〰 ～  ->  "〜" (modified by Isao Sonobe)
    s = re.sub("[~\u223C\u223E\u301C\u3030\uFF5E]+", "\u301C", s)

    # Map ASCII punctuation and half-width CJK punctuation to full-width forms.
    # The NFKC pass below folds most of them back to ASCII; the round trip
    # leaves "＝" full-width and turns the quotes into ” and ’.
    s = s.translate(
        str.maketrans(
            "!\"#$%&'()*+,-./:;<=>?@[\u00A5]^_`{|}~\uFF61\uFF64\uFF65\uFF62\uFF63",
            "\uFF01\u201D\uFF03\uFF04\uFF05\uFF06\u2019\uFF08\uFF09\uFF0A\uFF0B\uFF0C"
            "\uFF0D\uFF0E\uFF0F\uFF1A\uFF1B\uFF1C\uFF1D\uFF1E\uFF1F\uFF20\uFF3B\uFFE5"
            "\uFF3D\uFF3E\uFF3F\uFF40\uFF5B\uFF5C\uFF5D\u301C\u3002\u3001\u30FB\u300C\u300D",
        )
    )

    s = remove_extra_spaces(s)
    # NFKC-fold the full-width punctuation back to ASCII (keep ＝, ・, 「, 」).
    s = unicode_normalize(
        "\uFF01\u201D\uFF03\uFF04\uFF05\uFF06\u2019\uFF08\uFF09\uFF0A\uFF0B\uFF0C"
        "\uFF0D\uFF0E\uFF0F\uFF1A\uFF1B\uFF1C\uFF1E\uFF1F\uFF20\uFF3B\uFFE5\uFF3D"
        "\uFF3E\uFF3F\uFF40\uFF5B\uFF5C\uFF5D\u301C",
        s,
    )
    # NFKC leaves the curly quotes untouched; map them to ASCII explicitly.
    s = s.replace("\u2019", "'")
    s = s.replace("\u201D", '"')
    s = s.lower()
    return s
def normalize_text(text):
    """Normalize *text* with the NEologd-style rules.

    Thin wrapper: delegates to ``normalize_neologd`` and returns its result
    unchanged.
    """
    normalized = normalize_neologd(text)
    return normalized
# Suffix words that item titles append to their subject ("…のイラスト",
# "…のアイコン", …). Longer variants must precede their prefixes: regex
# alternation commits to the first alternative that lets the overall match
# succeed, so e.g. "イラスト文字" listed after "イラスト" could never match.
# Entries are regex fragments, not plain literals.
_TITLE_SUFFIXES = (
    "イラスト文字・バナー",
    "イラスト・メッセージ",
    "イラストPOP文字",
    "イラスト文字",
    "イラストの",
    "イラストト",
    "イラスト",
    "イラス",
    "イ子のラスト",
    "「イラスト文字」",
    "ペンキ文字",
    "タイトル文字",
    "キャラクター(たち)?",
    "マーク",
    "アイコン",
    "シルエット素材",
    "シルエット",
    "フレーム素材",
    "フレーム(枠)",
    "フレーム",
    "テンプレート",
    "パターン素材",
    "パターン",
    "ライン素材",
    "コーナー素材",
    "リボン型バナー",
    "評価スタンプ",
    "背景素材",
)
# One or more suffix words (each optionally preceded by "の"), then an optional
# ASCII/full-width number or "その<number>", then an optional trailing "です。".
_TITLE_SUFFIX_RE = re.compile(
    r"(の?(?:" + "|".join(_TITLE_SUFFIXES) + r"))+"
    r"(\s*([0-9\uFF10-\uFF19]*|その[0-9\uFF10-\uFF19]+))(です。)?"
)


def normalize_title(title):
    """Reduce an item title to its subject and normalize it.

    Strips surrounding 「」 or a POP素材「…」 wrapper, removes the suffix words
    listed in ``_TITLE_SUFFIXES`` (with optional numbering and "です。"), and
    then applies ``normalize_text``.

    Args:
        title: Raw title string.

    Returns:
        The normalized title.

    Raises:
        ValueError: If nothing but whitespace remains after normalization.
    """
    original = title
    title = title.strip()
    match = re.match(r"^「([^」]+)」$", title)
    if match:
        title = match.group(1)
    match = re.match(r"^POP素材「([^」]+)」$", title)
    if match:
        title = match.group(1)
    title = _TITLE_SUFFIX_RE.sub("", title)
    title = normalize_text(title)
    if not title.strip():
        raise ValueError(f"title is empty after normalization: {original!r}")
    return title
class SentenceBertJapanese:
    """Sentence encoder: a Japanese BERT model plus attention-masked mean pooling.

    Args:
        model_name_or_path: Model id or local path passed to ``from_pretrained``
            for both the tokenizer and the model.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.
    """

    def __init__(self, model_name_or_path, device=None):
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
        self.model = BertModel.from_pretrained(model_name_or_path)
        self.model.eval()  # inference mode for layers such as dropout
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(device)

    def _mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, ignoring padded positions.

        ``model_output[0]`` holds the per-token embeddings. The mask is
        broadcast over the last dimension so padding contributes nothing, and
        the divisor is clamped to avoid division by zero.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def encode(self, sentences, batch_size=8):
        """Embed ``sentences`` in batches of ``batch_size``.

        Runs under ``torch.no_grad()``, so no autograd graph is built and the
        result can be converted with ``.numpy()`` directly.

        Args:
            sentences: Sliceable sequence (e.g. a list) of strings.
            batch_size: Number of sentences per forward pass.

        Returns:
            CPU tensor with one embedding row per sentence. When ``sentences``
            is empty, the shape is ``(0, hidden_size)``.
        """
        batches = []
        with torch.no_grad():
            for start in range(0, len(sentences), batch_size):
                batch = sentences[start : start + batch_size]
                encoded_input = self.tokenizer.batch_encode_plus(
                    batch, padding="longest", truncation=True, return_tensors="pt"
                ).to(self.device)
                model_output = self.model(**encoded_input)
                pooled = self._mean_pooling(
                    model_output, encoded_input["attention_mask"]
                )
                batches.append(pooled.to("cpu"))
        if not batches:
            # torch.cat/stack reject an empty list; return an empty matrix instead.
            return torch.empty((0, self.model.config.hidden_size))
        return torch.cat(batches)
def _render_result(row, distance):
    """Build the Markdown/HTML snippet for one search hit.

    Every image of the item is shown as a 100px thumbnail linking to the item
    page. It is followed by a Markdown link whose text is half the cosine
    distance and the item description.
    """
    page_url = row["page"]
    thumbnails = "".join(
        f'<a href="{page_url}" target="_blank" rel="noopener noreferrer"><img src="{img_url}" width="100"></a>'
        for img_url in row["images"]
    )
    return thumbnails + f'\n[{distance / 2:.4f}: {row["description"]}]({page_url})'


st.title("いらすと検索")
description_text = st.empty()

# Load the model and dataset once and keep them in session_state, so that
# script reruns reuse them.
if "model" not in st.session_state:
    description_text.text("...モデル読み込み中...")
    st.session_state.model = SentenceBertJapanese(
        "sonoisa/sentence-bert-base-ja-mean-tokens"
    )
    # Extract the password-protected archive into the working directory.
    pyminizip.uncompress(
        "irasuto_items_20210224.pq.zip", st.secrets["ZIP_PASSWORD"], None, 1
    )
    df = pq.read_table("irasuto_items_20210224.parquet").to_pandas()
    st.session_state.df = df
    st.session_state.sentence_vectors = np.stack(df["sentence_vector"])

model = st.session_state.model
df = st.session_state.df
sentence_vectors = st.session_state.sentence_vectors

description_text.text("説明文の意味が近い「いらすとや」画像を検索します。\nキーワードを列挙するよりも、自然な文章を入力した方が精度よく検索できます。")

query_input = st.text_input(label="説明文", value="")
closest_n = st.number_input(label="検索数", min_value=1, value=10, max_value=100)
# Pressing the button re-executes the script; the search below then runs for
# whatever non-blank query is in the text box.
st.button("検索")

if query_input.strip():
    query = normalize_text(query_input)
    # detach() keeps .numpy() safe even if encode() returns a grad-tracking tensor.
    query_embedding = model.encode([query]).detach().numpy()
    distances = scipy.spatial.distance.cdist(
        query_embedding, sentence_vectors, metric="cosine"
    )[0]
    # A stable sort keeps ties in index order, matching sorted() on (idx, dist).
    for idx in np.argsort(distances, kind="stable")[:closest_n]:
        st.markdown(
            _render_result(df.iloc[idx], distances[idx]), unsafe_allow_html=True
        )