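# Gradio demo: a binary Japanese text classifier that, given a short line of text,
# outputs the probability that it reads like ちいかわ (Chiikawa) or 米津玄師 (Kenshi Yonezu).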
from pprint import pprint

import gradio as gr
import torch
from safetensors import safe_open
from transformers import BertTokenizer

from utils.ClassifierModel import ClassifierModel


def _classify_text(text, model, device, tokenizer, max_length=20):
    """
    Return the probabilities that the text belongs to 'ちいかわ' (Chiikawa)
    or '米津玄師' (Kenshi Yonezu).
    """
    # Tokenize the text and convert it to PyTorch tensors
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    pprint(f"inputs: {inputs}")

    # Run the model in inference mode
    model.eval()
    with torch.no_grad():
        outputs = model(
            inputs["input_ids"].to(device), inputs["attention_mask"].to(device)
        )
    pprint(f"outputs: {outputs}")
    probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # Read out the per-class probabilities
    chiikawa_prob = probabilities[0][0].item()
    yonezu_prob = probabilities[0][1].item()
    return chiikawa_prob, yonezu_prob


is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")
pprint(f"device: {device}")
model_save_path = "models/model.safetensors"
tensors = {}
with safe_open(model_save_path, framework="pt", device="cpu") as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)
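
# ClassifierModel is defined in utils/ClassifierModel.py, which is not shown here.
# Based on how it is called in _classify_text, it is assumed to take
# (input_ids, attention_mask) and return raw logits of shape (batch_size, 2),
# with index 0 = ちいかわ and index 1 = 米津玄師.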
inference_model: torch.nn.Module = ClassifierModel().to(device)
inference_model.load_state_dict(tensors)

tokenizer = BertTokenizer.from_pretrained(
    "cl-tohoku/bert-base-japanese-whole-word-masking"
)
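# NOTE: this checkpoint is usually paired with BertJapaneseTokenizer, which does
# MeCab-based word segmentation (and needs the fugashi/ipadic packages). Plain
# BertTokenizer loads the same vocabulary but segments more crudely, so the class
# used here should match whatever tokenizer the classifier was trained with.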


def classify_text(text):
    chii_prob, yone_prob = _classify_text(text, inference_model, device, tokenizer)
    return {"ちいかわ": chii_prob, "米津玄師": yone_prob}


demo = gr.Interface(
    fn=classify_text,
    inputs="textbox",
    outputs="label",
    examples=[
        "守りたいんだ",
        "どうしてどうしてどうして",
        "そこから見ていてね",
        "ヤンパパン",
    ],
)
demo.launch(share=True)  # Share your demo with just 1 extra parameter 🚀
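
# Usage note (assumption: this file is the Space's app.py): running it locally with
# `python app.py` starts the Gradio server on its default http://127.0.0.1:7860, and
# share=True additionally prints a temporary public *.gradio.live link.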