Spaces:
Running
Running
import gradio as gr | |
import requests | |
import torch | |
from transformers import AutoModel, AutoConfig | |
from huggingface_hub import HfApi | |
import safetensors | |
import os | |
def convert_and_deploy(url, repo_id, hf_token): | |
# セーフテンソルファイルをダウンロード | |
response = requests.get(url) | |
if response.status_code != 200: | |
return "ファイルのダウンロードに失敗しました。URLを確認してください。" | |
# ファイルを保存 | |
file_path = "model.safetensors" | |
with open(file_path, "wb") as f: | |
f.write(response.content) | |
# ファイルの存在を確認 | |
if not os.path.exists(file_path): | |
return "ファイルが正しく保存されませんでした。" | |
# ファイルの内容を確認 | |
with open(file_path, "rb") as f: | |
content = f.read(100) # 先頭100バイトを読み込む | |
if not content: | |
return "ファイルが空です。" | |
# モデルを読み込み | |
try: | |
# モデルの構成を取得 | |
model_name = repo_id.split('/')[-1] # モデル名を取得 | |
config = AutoConfig.from_pretrained(model_name, token=hf_token) | |
# モデルを構成に基づいて初期化 | |
model = AutoModel.from_config(config, torch_dtype=torch.float16) | |
# セーフテンソルファイルからモデルの状態を読み込み | |
with safetensors.safe_open(file_path, framework="pt") as f: | |
state_dict = {k: f.get_tensor(k) for k in f.keys()} | |
# BF16からFP16に変換 | |
state_dict = {k: v.to(torch.float16) for k, v in state_dict.items()} | |
# モデルの状態を設定 | |
model.load_state_dict(state_dict) | |
except Exception as e: | |
return f"モデルの読み込みに失敗しました。エラー: {e}" | |
# モデルをfloat16形式で保存 | |
try: | |
model.save_pretrained(f"{model_name}_float16", torch_dtype=torch.float16) | |
except Exception as e: | |
return f"モデルの保存に失敗しました。エラー: {e}" | |
# モデルをHugging Faceにデプロイ | |
api = HfApi() | |
try: | |
api.upload_folder( | |
folder_path=f"{model_name}_float16", | |
repo_id=repo_id, | |
token=hf_token, | |
path_in_repo=f"{model_name}_float16", | |
create_remote_repo=True | |
) | |
except Exception as e: | |
return f"モデルのデプロイに失敗しました。エラー: {e}" | |
return "モデルをfloat16に変換し、Hugging Faceにデプロイしました。" | |
# Gradioインターフェースの作成 | |
iface = gr.Interface( | |
fn=convert_and_deploy, | |
inputs=[ | |
gr.Text(label="セーフテンソルURL"), | |
gr.Text(label="Hugging Face リポジトリID (ユーザー名/モデル名)"), | |
gr.Text(label="Hugging Face Write Token") | |
], | |
outputs=gr.Text(label="結果"), | |
title="モデルの変換とデプロイ", | |
description="セーフテンソルURL、Hugging Face リポジトリID (ユーザー名/モデル名)、およびHugging Face Write Tokenを入力して、モデルをfloat16に変換し、Hugging Faceにデプロイします。" | |
) | |
# インターフェースの起動 | |
iface.launch() |