Spaces:
Sleeping
Sleeping
import gradio as gr | |
from huggingface_hub import snapshot_download | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import os | |
# モデルをダウンロードするディレクトリ | |
model_dir = "models/Miwa-Keita/zenz-v1-checkpoints" | |
# 不要なファイルを除外し、特定のファイルのみダウンロード | |
snapshot_download( | |
repo_id="Miwa-Keita/zenz-v1-checkpoints", | |
local_dir=model_dir, | |
allow_patterns=["*.bin", "*.json", "*.txt", "*.model"], # 必要なファイルだけ取得 | |
ignore_patterns=["optimizer.pt", "checkpoint*"], # いらないファイルを無視 | |
) | |
# モデルとトークナイザーのロード(GPT-2 アーキテクチャ) | |
tokenizer = AutoTokenizer.from_pretrained(model_dir) | |
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float32) | |
# 入力を調整する関数 | |
def preprocess_input(user_input): | |
prefix = "\uEE00" # 前に付与する文字列 | |
suffix = "\uEE01" # 後ろに付与する文字列 | |
processed_input = prefix + user_input + suffix | |
return processed_input | |
# 出力を調整する関数 | |
def postprocess_output(model_output): | |
suffix = "\uEE01" | |
# \uEE01の後の部分を抽出 | |
if suffix in model_output: | |
return model_output.split(suffix)[1] | |
return model_output | |
# 変換関数 | |
def generate_text(user_input): | |
processed_input = preprocess_input(user_input) | |
# テキストをトークン化 | |
inputs = tokenizer(processed_input, return_tensors="pt") | |
# モデルで生成 | |
outputs = model.generate(**inputs, max_length=100) | |
# 出力のデコード | |
decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# 出力の整形 | |
return postprocess_output(decoded_output) | |
# Gradio インターフェース | |
iface = gr.Interface( | |
fn=generate_text, | |
inputs=gr.Textbox(label="変換する文字列(カタカナ)"), | |
outputs=gr.Textbox(label="変換結果"), | |
title="ニューラルかな漢字変換モデル zenz-v1 のデモ", | |
description="変換したい文字列をカタカナを入力してください" | |
) | |
# ローンチ | |
iface.launch() |