File size: 4,358 Bytes
5c7acaf
 
 
2c2a43f
 
 
59e8362
60b8eef
75b8738
fb6d698
2c2a43f
5c7acaf
8c83344
 
5c7acaf
8c83344
 
 
2c2a43f
5c7acaf
 
 
 
 
 
 
 
 
2c2a43f
75b8738
 
 
 
 
 
 
 
 
 
2c2a43f
54af0d0
59e8362
a8981a6
5c7acaf
59e8362
 
 
 
b4e040e
75b8738
 
b4e040e
75b8738
 
 
b4e040e
 
54af0d0
 
2c2a43f
 
7d8e922
99b4947
7d8e922
 
2c2a43f
 
 
54af0d0
 
 
f060548
5c7acaf
54af0d0
 
 
 
60b8eef
 
 
 
 
 
2c2a43f
 
 
 
5c7acaf
 
 
 
 
 
 
 
 
 
 
 
2c2a43f
5c7acaf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# 必要なライブラリのインストール
# !pip install gradio huggingface_hub requests transformers safetensors torch

import gradio as gr
import requests
import torch
from transformers import AutoModel, AutoConfig
from huggingface_hub import HfApi, login
import safetensors
import os

def convert_and_deploy(url, repo_id, hf_write_token):
    # Hugging Face Hubにログイン
    try:
        login(token=hf_write_token)
    except Exception as e:
        return f"Hugging Face Hubへのログインに失敗しました。エラー: {e}"
    
    # セーフテンソルファイルをダウンロード
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        file_path = "model.safetensors"
        with open(file_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    except requests.exceptions.RequestException as e:
        return f"ファイルのダウンロードに失敗しました。エラー: {e}"
    
    # ファイルの存在を確認
    if not os.path.exists(file_path):
        return "ファイルが正しく保存されませんでした。"
    
    # ファイルの内容を確認
    with open(file_path, "rb") as f:
        content = f.read(100)  # 先頭100バイトを読み込む
        if not content:
            return "ファイルが空です。"
    
    # モデルを読み込み
    try:
        # モデルの構成を取得
        model_name = repo_id.split('/')[-1]  # モデル名を取得
        config = AutoConfig.from_pretrained(model_name, token=hf_write_token)
        
        # モデルを構成に基づいて初期化
        model = AutoModel.from_config(config, torch_dtype=torch.float16)
        
        # セーフテンソルファイルからモデルの状態を読み込み
        with safetensors.safe_open(file_path, framework="pt") as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
        
        # BF16からFP16に変換
        state_dict = {k: v.to(torch.float16) for k, v in state_dict.items()}
        
        # モデルの状態を設定
        model.load_state_dict(state_dict)
    except Exception as e:
        return f"モデルの読み込みに失敗しました。エラー: {e}"
    
    # モデルをfloat16形式で保存
    try:
        model.save_pretrained(f"{model_name}_float16", torch_dtype=torch.float16)
    except Exception as e:
        return f"モデルの保存に失敗しました。エラー: {e}"
    
    # モデルをHugging Faceにデプロイ
    api = HfApi()
    try:
        api.upload_folder(
            folder_path=f"{model_name}_float16",
            repo_id=repo_id,
            token=hf_write_token,
            path_in_repo=f"{model_name}_float16",
            create_remote_repo=True
        )
    except Exception as e:
        if "does not exist" in str(e):
            return f"指定したモデル名 '{repo_id}' はHugging Face Hubに存在しません。モデル名を確認してください。"
        elif "token" in str(e):
            return f"トークンが無効または権限がありません。正しいトークンを確認してください。"
        else:
            return f"モデルのデプロイに失敗しました。エラー: {e}"
    
    return "モデルをfloat16に変換し、Hugging Faceにデプロイしました。"

# Gradioインターフェースの作成
with gr.Blocks() as demo:
    gr.Markdown("# モデルの変換とデプロイ")

    download_url = gr.Textbox(label="セーフテンソルURL", placeholder="セーフテンソルファイルのダウンロードリンクを入力してください")
    hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Hugging Faceの書き込みトークンを入力してください", type="password")
    repo_id = gr.Textbox(label="Hugging Face リポジトリID (ユーザー名/モデル名)", placeholder="Hugging FaceのリポジトリIDを入力してください(例:ユーザー名/モデル名)")

    output = gr.Textbox(label="結果")

    upload_button = gr.Button("ダウンロードしてデプロイ")

    upload_button.click(convert_and_deploy, inputs=[download_url, repo_id, hf_write_token], outputs=output)

# アプリの実行
demo.launch()