# handler.py: place this file at the root of the model repository
from typing import Dict, Any
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch


class EndpointHandler:
    """
    Hugging Face Inference Endpoints 约定的自定义入口:
      • __init__(model_dir, **kwargs)   —— 加载模型
      • __call__(inputs: Dict) -> Dict  —— 处理一次请求
    """

    def __init__(self, model_dir: str, **kwargs):
        # 1️⃣ Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=True
        )

        # 2️⃣ Build an empty "shell" model on the meta device (uses no GPU memory yet).
        #    from_config only builds the architecture; the real weights are loaded in step 3.
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        with init_empty_weights():
            base_model = AutoModelForCausalLM.from_config(
                config,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )

        # 3️⃣ Load the checkpoint shards and dispatch them across both GPUs
        self.model = load_checkpoint_and_dispatch(
            base_model,
            checkpoint=model_dir,
            device_map="auto",                # 自动分层到 cuda:0 / cuda:1
            dtype=torch.float16,
        )

        # 4️⃣ Default generation parameters used at inference time
        self.generation_kwargs = dict(
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """
        Expected request format:
          {
            "inputs": "your prompt here"
          }
        """
        prompt = data["inputs"]

        # ➡️ Put the input tensors on cuda:0 only (the same device as the model's first layers)
        inputs = self.tokenizer(prompt, return_tensors="pt").to("cuda:0")

        # Generate
        with torch.inference_mode():
            output_ids = self.model.generate(**inputs, **self.generation_kwargs)

        generated_text = self.tokenizer.decode(
            output_ids[0], skip_special_tokens=True
        )
        return {"generated_text": generated_text}