File size: 2,249 Bytes
eb21223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2eec75b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import HfApi, create_repo

# モデルのONNXエクスポート関数
def convert_to_onnx_and_deploy(model_repo, input_text, hf_token, repo_name):
    try:
        # Hugging Faceトークンを設定
        os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token
        
        # モデルとトークナイザーの読み込み
        tokenizer = AutoTokenizer.from_pretrained(model_repo)
        model = AutoModelForCausalLM.from_pretrained(model_repo)

        # 入力テキストをトークナイズ
        inputs = tokenizer(input_text, return_tensors="pt")

        # ONNXファイルの保存
        onnx_file = f"{repo_name}.onnx"
        torch.onnx.export(
            model,
            inputs['input_ids'],
            onnx_file,
            input_names=['input_ids'],
            output_names=['output'],
            dynamic_axes={'input_ids': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
        )

        # モデルをHugging Face Hubにデプロイ
        api = HfApi()
        create_repo(repo_name, private=True)  # プライベートリポジトリを作成
        api.upload_file(onnx_file, repo_id=repo_name)  # ONNXファイルをアップロード

        return f"ONNXモデルが作成され、リポジトリにデプロイされました: {repo_name}"
    except Exception as e:
        return str(e)

# Gradioインターフェース
iface = gr.Interface(
    fn=convert_to_onnx_and_deploy,
    inputs=[
        gr.Textbox(label="モデルリポジトリ(例: rinna/japanese-gpt2-medium)"),
        gr.Textbox(label="入力テキスト"),
        gr.Textbox(label="Hugging Faceトークン", type="password"),  # パスワード入力タイプ
        gr.Textbox(label="デプロイ先リポジトリ名")  # デプロイ先のリポジトリ名
    ],
    outputs="text",
    title="ONNX変換とモデルデプロイ機能",
    description="指定したHugging FaceのモデルリポジトリをONNX形式に変換し、デプロイします。"
)

# 使用するポート番号を指定してインターフェースを起動
iface.launch(server_port=7865)  # 7865ポートを指定