Ayemos's picture
add mean surprisals as output
bf4b865
raw
history blame
3.17 kB
import html
from typing import List, Tuple

import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForCausalLM, T5Tokenizer
# Run on the first CUDA GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Tokenizer for rinna/japanese-gpt2-medium, loaded through T5Tokenizer.
# NOTE(review): confirm that T5Tokenizer actually consults `do_lower_case`.
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
tokenizer.do_lower_case = True

# Causal LM whose next-token distribution is used for scoring; moved to
# `device` once, at import time (nn.Module.to returns the module itself).
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium").to(device)
def calculate_surprisals(
    input_text: str, normalize_surprisals: bool = True
) -> Tuple[float, List[Tuple[str, float]]]:
    """Score each token of ``input_text`` by its surprisal under the language model.

    The surprisal of a token is ``-log p(token | preceding tokens)`` (natural
    log), read from the model's next-token distribution. The text is prefixed
    with ``"<s>"`` so that the first real token is also predicted from context.

    Args:
        input_text: Text to score.
        normalize_surprisals: If True, min-max scale the per-token surprisals
            into ``[0, 1]``. The returned mean is always computed on the raw
            (unscaled) values.

    Returns:
        ``(mean_surprisal, tokens2surprisal)``. ``mean_surprisal`` is the mean
        raw surprisal over the scored tokens. ``tokens2surprisal`` pairs each
        displayed token with its (possibly normalized) surprisal. When the
        text yields no scorable tokens, ``(0.0, [])`` is returned.
    """
    # NOTE(review): id 9 is assumed to be the standalone "▁" (word-boundary)
    # piece, mirroring the `token != "▁"` filter below — confirm against the
    # rinna vocabulary.
    space_token_id = 9

    # Tokens as shown to the user: strip the "▁" marker and drop pieces that
    # consisted of nothing but that marker.
    input_tokens = [
        token.replace("▁", "")
        for token in tokenizer.tokenize(input_text)
        if token != "▁"
    ]

    input_ids = tokenizer.encode(
        "<s>" + input_text, add_special_tokens=False, return_tensors="pt"
    ).to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(input_ids)["logits"].squeeze(0)

    # Row i of the logits predicts token i + 1. So drop the last row and the
    # first (<s>) target. log_softmax is numerically stable, whereas
    # log(softmax(...)) underflows to -inf for very unlikely tokens.
    log_probs = torch.log_softmax(logits[:-1], dim=-1)
    targets = input_ids[0, 1:]
    token_surprisals = -log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    surprisals: List[float] = token_surprisals[targets != space_token_id].tolist()

    if not surprisals:
        # Empty / whitespace-only input: nothing to score, and min/max/mean
        # of an empty sequence would raise or produce NaN.
        return 0.0, []

    mean_surprisal = float(np.mean(surprisals))

    if normalize_surprisals:
        min_surprisal = min(surprisals)
        span = max(surprisals) - min_surprisal
        if span > 0:
            surprisals = [(s - min_surprisal) / span for s in surprisals]
        else:
            # Every token equally surprising (e.g. a single token): avoid 0/0.
            surprisals = [0.0] * len(surprisals)
        assert min(surprisals) >= 0
        assert max(surprisals) <= 1

    # zip stops at the shorter sequence. Both sequences exclude the bare "▁"
    # piece, so they are expected to align one-to-one.
    return mean_surprisal, list(zip(input_tokens, surprisals))
def highlight_token(token: str, score: float):
    """Wrap ``token`` in a ``<span>`` whose red intensity reflects ``score``.

    ``score`` is expected in ``[0, 1]``: 0 renders white (``#FFFFFF``) and 1
    renders pure red (``#FF0000``). Values outside that range are clamped so
    the colour is always a valid hex triplet.

    The token is HTML-escaped because it is derived from user-supplied text
    and the result is rendered as raw HTML.
    """
    clamped = min(max(score, 0.0), 1.0)
    # Green and blue fade together, so higher scores look redder.
    fade = int(255 * (1 - clamped))
    html_color = f"#FF{fade:02X}{fade:02X}"
    return '<span style="background-color: {}; color: black">{}</span>'.format(
        html_color, html.escape(token)
    )
def create_highlighted_text(tokens2scores: List[Tuple[str, float]]) -> str:
    """Render ``(token, score)`` pairs as one HTML string of highlighted spans.

    Each pair is rendered with ``highlight_token``. The spans are
    concatenated in input order and followed by a double line break.
    """
    # str.join builds the result in a single pass instead of repeated +=.
    spans = "".join(highlight_token(token, score) for token, score in tokens2scores)
    return spans + "<br><br>"
def main(input_text: str) -> Tuple[float, str]:
    """Gradio callback: score ``input_text`` and render it for display.

    Returns the mean surprisal rounded to two decimals, together with the HTML
    in which each token is shaded by its min-max normalized surprisal.
    """
    scored = calculate_surprisals(input_text, normalize_surprisals=True)
    mean, per_token = scored
    rendered_html = create_highlighted_text(per_token)
    return round(mean, 2), rendered_html
if __name__ == "__main__":
    # Demo UI. Title: "AI that detects hard-to-read passages (demo)".
    # Description: "Enter text and it is output highlighted according to how
    # hard it is to read."
    # The top-level gr.Textbox / gr.HTML components replace the legacy
    # gr.inputs / gr.outputs namespaces (deprecated in Gradio 3, removed in
    # Gradio 4). This matches the top-level gr.Number already used here.
    demo = gr.Interface(
        fn=main,
        title="読みにくい箇所を検出するAI(デモ)",
        description="テキストを入力すると、読みにくさに応じてハイライトされて出力されます。",
        inputs=gr.Textbox(
            lines=5, label="テキスト", placeholder="ここにテキストを入力してください。"
        ),
        outputs=[
            # Mean surprisal of the whole text, as returned by `main`.
            gr.Number(label="文全体の読みにくさ(サプライザル)"),
            # Per-token highlighted HTML, as returned by `main`.
            gr.HTML(label="トークン毎サプライザル"),
        ],
    )
    demo.launch()