Sakalti's picture
Update app.py
60b8eef verified
raw
history blame
4.36 kB
# 必要なライブラリのインストール
# !pip install gradio huggingface_hub requests transformers safetensors torch
import gradio as gr
import requests
import torch
from transformers import AutoModel, AutoConfig
from huggingface_hub import HfApi, login
import safetensors
import os
def convert_and_deploy(url, repo_id, hf_write_token):
# Hugging Face Hubにログイン
try:
login(token=hf_write_token)
except Exception as e:
return f"Hugging Face Hubへのログインに失敗しました。エラー: {e}"
# セーフテンソルファイルをダウンロード
try:
response = requests.get(url, stream=True)
response.raise_for_status()
file_path = "model.safetensors"
with open(file_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
except requests.exceptions.RequestException as e:
return f"ファイルのダウンロードに失敗しました。エラー: {e}"
# ファイルの存在を確認
if not os.path.exists(file_path):
return "ファイルが正しく保存されませんでした。"
# ファイルの内容を確認
with open(file_path, "rb") as f:
content = f.read(100) # 先頭100バイトを読み込む
if not content:
return "ファイルが空です。"
# モデルを読み込み
try:
# モデルの構成を取得
model_name = repo_id.split('/')[-1] # モデル名を取得
config = AutoConfig.from_pretrained(model_name, token=hf_write_token)
# モデルを構成に基づいて初期化
model = AutoModel.from_config(config, torch_dtype=torch.float16)
# セーフテンソルファイルからモデルの状態を読み込み
with safetensors.safe_open(file_path, framework="pt") as f:
state_dict = {k: f.get_tensor(k) for k in f.keys()}
# BF16からFP16に変換
state_dict = {k: v.to(torch.float16) for k, v in state_dict.items()}
# モデルの状態を設定
model.load_state_dict(state_dict)
except Exception as e:
return f"モデルの読み込みに失敗しました。エラー: {e}"
# モデルをfloat16形式で保存
try:
model.save_pretrained(f"{model_name}_float16", torch_dtype=torch.float16)
except Exception as e:
return f"モデルの保存に失敗しました。エラー: {e}"
# モデルをHugging Faceにデプロイ
api = HfApi()
try:
api.upload_folder(
folder_path=f"{model_name}_float16",
repo_id=repo_id,
token=hf_write_token,
path_in_repo=f"{model_name}_float16",
create_remote_repo=True
)
except Exception as e:
if "does not exist" in str(e):
return f"指定したモデル名 '{repo_id}' はHugging Face Hubに存在しません。モデル名を確認してください。"
elif "token" in str(e):
return f"トークンが無効または権限がありません。正しいトークンを確認してください。"
else:
return f"モデルのデプロイに失敗しました。エラー: {e}"
return "モデルをfloat16に変換し、Hugging Faceにデプロイしました。"
# Gradioインターフェースの作成
with gr.Blocks() as demo:
gr.Markdown("# モデルの変換とデプロイ")
download_url = gr.Textbox(label="セーフテンソルURL", placeholder="セーフテンソルファイルのダウンロードリンクを入力してください")
hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Hugging Faceの書き込みトークンを入力してください", type="password")
repo_id = gr.Textbox(label="Hugging Face リポジトリID (ユーザー名/モデル名)", placeholder="Hugging FaceのリポジトリIDを入力してください(例:ユーザー名/モデル名)")
output = gr.Textbox(label="結果")
upload_button = gr.Button("ダウンロードしてデプロイ")
upload_button.click(convert_and_deploy, inputs=[download_url, repo_id, hf_write_token], outputs=output)
# アプリの実行
demo.launch()