File size: 4,377 Bytes
bfefc44
 
99a3dfb
bfefc44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a3dfb
bfefc44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import requests
from typing import Optional


# 定数定義
API_BASE_URL = "https://huggingface.co/api/models/"
LLAMA_CPP_SOURCE = "https://raw.githubusercontent.com/ggml-org/llama.cpp/refs/heads/master/convert_hf_to_gguf.py"
QUANTIZE = {
    "IQ4_XS": 4.25,
    "Q4_K_M": 4.9,
    "Q5_K_M": 5.7,
    "Q6_K": 6.6,
}

def get_model_api_info(model_name: str) -> dict:
    """
    Hugging Face Hub APIから指定モデルの情報を取得する。

    Args:
        model_name (str): モデル名(例: "meta-llama/Llama-3.3-70B-Instruct")

    Returns:
        dict: モデル情報の辞書。取得に失敗した場合は空の辞書を返す。
    """
    api_url = f"{API_BASE_URL}{model_name}"
    try:
        response = requests.get(api_url)
        response.raise_for_status()
        return response.json()
    except:
        return {}

def is_architecture_supported(architecture: str) -> Optional[bool]:
    """
    llama.cppのソースコード内に指定アーキテクチャが含まれているかチェックする。

    Args:
        architecture (str): アーキテクチャ名

    Returns:
        Optional[bool]:
            - True: アーキテクチャがサポートされている場合
            - False: サポートされていない場合
            - None: ソースコードの取得に失敗した場合
    """
    try:
        response = requests.get(LLAMA_CPP_SOURCE)
        response.raise_for_status()
        content = response.text
        return architecture in content
    except:
        return None

def estimate_gpu_memory(model_name: str) -> str:
    """
    指定したモデル名からAPI情報を取得し、safetensors内の各精度パラメータサイズの合算値から
    GPUメモリ必要量を概算する。

    Args:
        model_name (str): モデル名

    Returns:
        str: GPUメモリ必要量などの情報を含むメッセージ文字列
    """
    result_lines = []

    model_info = get_model_api_info(model_name)
    if not model_info:
        result_lines.append(f"エラー: モデル '{model_name}' の情報が取得できませんでした。")
        return "\n".join(result_lines)

    parameters = model_info.get("safetensors", {}).get("parameters")
    if parameters is None:
        result_lines.append("safetensorsの情報が見つかりませんでした。")
        return "\n".join(result_lines)

    architectures = model_info.get("config", {}).get("architectures", [])
    if not architectures or architectures[0] is None:
        result_lines.append("モデルアーキテクチャの情報が見つかりませんでした。")
        return "\n".join(result_lines)

    # アーキテクチャの対応確認
    architecture = architectures[0]
    arch_supported = is_architecture_supported(architecture)
    if arch_supported is None:
        result_lines.append("llama.cppのソースコード参照に失敗しました。")
        return "\n".join(result_lines)
    if not arch_supported:
        result_lines.append(f"{architecture}はllama.cppにて未対応のアーキテクチャです。")
        return "\n".join(result_lines)

    result_lines.append(f"モデル: {model_name}")
    result_lines.append(f"アーキテクチャ: {architecture}")
    result_lines.append("")
    result_lines.append("各量子化モデルの動作に必要な概算のGPUメモリサイズは")

    # GPUメモリサイズの計算
    total_params = sum(parameters.values())
    for quant, multiplier in QUANTIZE.items():
        gpu_memory = total_params * multiplier / 8 / (1024 ** 3) * 1.3
        result_lines.append(f"【{quant}】約 {gpu_memory:.2f} GB")

    result_lines.append("となります。")
    result_lines.append("")
    result_lines.append("参考:Qの数字が量子化ビット数です。")
    return "\n".join(result_lines)

# Gradio インターフェースの定義
iface = gr.Interface(
    fn=estimate_gpu_memory,
    inputs=gr.Textbox(label="モデル名 (例: meta-llama/Llama-3.3-70B-Instruct)"),
    outputs="text",
    title="Quantized model memory estimator",
    description=("Hugging Face Hub APIから取得したsafetensorsの情報をもとに、llama.cppで量子化を行ったモデルの動作に必要なGPUメモリサイズを概算(GB単位)で計算します。")
)

iface.launch()