File size: 2,204 Bytes
06a4fa8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e0e168
 
 
 
06a4fa8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from pprint import pprint

import gradio as gr
import torch

from safetensors import safe_open
from transformers import BertTokenizer

from utils.ClassifierModel import ClassifierModel


def _classify_text(text, model, device, tokenizer, max_length=20):
    """
    テキストが、'ちいかわ' と '米津玄師' のどちらに該当するかの確率を出力する。
    """

    # テキストをトークナイズし、PyTorchのテンソルに変換
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    pprint(f"inputs: {inputs}")

    # モデルの推論
    model.eval()
    with torch.no_grad():
        outputs = model(
            inputs["input_ids"].to(device), inputs["attention_mask"].to(device)
        )
        pprint(f"outputs: {outputs}")
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # 確率の取得
    chiikawa_prob = probabilities[0][0].item()
    yonezu_prob = probabilities[0][1].item()

    return chiikawa_prob, yonezu_prob


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device("cuda") if is_cuda else torch.device("cpu")
pprint(f"device: {device}")

# Read every tensor stored in the safetensors checkpoint (opened on the CPU)
# into a plain dict that serves as the state dict for the classifier.
model_save_path = "models/model.safetensors"
with safe_open(model_save_path, framework="pt", device="cpu") as f:
    tensors = {key: f.get_tensor(key) for key in f.keys()}

# Build the classifier on the selected device and load the stored weights.
inference_model: torch.nn.Module = ClassifierModel().to(device)
inference_model.load_state_dict(tensors)

# Tokenizer used to encode the text passed to the classifier.
tokenizer = BertTokenizer.from_pretrained(
    "cl-tohoku/bert-base-japanese-whole-word-masking"
)


def classify_text(text):
    """Classify *text* and return a label-to-probability mapping.

    Delegates to ``_classify_text`` using the module-level model, device and
    tokenizer, and labels the two returned probabilities with the display
    names used as dictionary keys.
    """
    chiikawa_score, yonezu_score = _classify_text(
        text=text,
        model=inference_model,
        device=device,
        tokenizer=tokenizer,
    )
    label_scores = {
        "ちいかわ": chiikawa_score,
        "米津玄師": yonezu_score,
    }
    return label_scores


# Sample inputs offered in the demo UI.
example_texts = [
    "守りたいんだ",
    "どうしてどうしてどうして",
    "そこから見ていてね",
    "ヤンパパン",
]

# Text box in, label out, backed by classify_text.
demo = gr.Interface(
    fn=classify_text,
    inputs="textbox",
    outputs="label",
    examples=example_texts,
)

# share=True is the one extra parameter that shares the demo.
demo.launch(share=True)