import os

# Default client LLM settings. INFERENCE_SERVER_URL is a template; its
# {model_id} placeholder is presumably formatted elsewhere with a concrete
# model id such as MODEL_ID.
INFERENCE_SERVER_URL = "https://api-inference.huggingface.co/models/{model_id}"
MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
CLIENT_MODEL_KWARGS = {
    "max_tokens": 800,
    "temperature": 0.6,
}

# Default guide settings: the expert LLM plus an NLI classifier.
GUIDE_KWARGS = {
    "expert_model": "HuggingFaceH4/zephyr-7b-beta",
                  # "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "inference_server_url": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta",
                  # "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-70B-Instruct",
    "llm_backend": "HFChat",
    "classifier_kwargs": {
        "model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
        "inference_server_url": "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
        "batch_size": 8,
    },
}


def process_config(config):
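    """Validate a parsed config.yaml dict and split it into kwargs.

    Returns a (client_kwargs, guide_kwargs) tuple. Raises ValueError if the
    HF_TOKEN environment variable is unset or if any required section
    (client_llm, expert_llm, classifier_llm) or key within it is missing.
    """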
    if "HF_TOKEN" not in os.environ:
        raise ValueError("Please set the HF_TOKEN environment variable.")
    client_kwargs = {}
    if "client_llm" in config:
        if "model_id" in config["client_llm"]:
            client_kwargs["model_id"] = config["client_llm"]["model_id"]
        else:
            raise ValueError("config.yaml is missing client model_id.")
        if "url" in config["client_llm"]:
            client_kwargs["inference_server_url"] = config["client_llm"]["url"]
        else:
            raise ValueError("config.yaml is missing client url.")
        client_kwargs["api_key"] = os.getenv("HF_TOKEN")
        client_kwargs["llm_backend"] = "HFChat"
        client_kwargs["temperature"] = CLIENT_MODEL_KWARGS["temperature"]
        client_kwargs["max_tokens"] = CLIENT_MODEL_KWARGS["max_tokens"]
    else:
        raise ValueError("config.yaml is missing client_llm settings.")
    
    guide_kwargs = {}
    if "expert_llm" in config:
        if "model_id" in config["expert_llm"]:
            guide_kwargs["expert_model"] = config["expert_llm"]["model_id"]
        else:
            raise ValueError("config.yaml is missing expert model_id.")
        if "url" in config["expert_llm"]:
            guide_kwargs["inference_server_url"] = config["expert_llm"]["url"]
        else:
            raise ValueError("config.yaml is missing expert url.")
        guide_kwargs["api_key"] = os.getenv("HF_TOKEN")
        guide_kwargs["llm_backend"] = "HFChat"
    else:
        raise ValueError("config.yaml is missing expert_llm settings.")

    if "classifier_llm" in config:
        guide_kwargs["classifier_kwargs"] = {}  # nested dict must exist before keys are assigned into it
        if "model_id" in config["classifier_llm"]:
            guide_kwargs["classifier_kwargs"]["model_id"] = config["classifier_llm"]["model_id"]
        else:
            raise ValueError("config.yaml is missing classifier model_id.")
        if "url" in config["classifier_llm"]:
            guide_kwargs["classifier_kwargs"]["inference_server_url"] = config["classifier_llm"]["url"]
        else:
            raise ValueError("config.yaml is missing classifier url.")
        if "batch_size" in config["classifier_llm"]:
            guide_kwargs["classifier_kwargs"]["batch_size"] = config["classifier_llm"]["batch_size"]
        else:
            raise ValueError("config.yaml is missing classifier batch_size.")
        guide_kwargs["classifier_kwargs"]["api_key"] = os.getenv("HF_TOKEN")  # classifier api key
    else:
        raise ValueError("config.yaml is missing classifier_llm settings.")

    return client_kwargs, guide_kwargs
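

# Illustrative usage sketch, not part of the original module. It assumes a
# config.yaml next to this file shaped like the following, and requires
# HF_TOKEN to be set in the environment:
#
#   client_llm:
#     model_id: HuggingFaceH4/zephyr-7b-beta
#     url: https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta
#   expert_llm:
#     model_id: HuggingFaceH4/zephyr-7b-beta
#     url: https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta
#   classifier_llm:
#     model_id: MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli
#     url: https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli
#     batch_size: 8
if __name__ == "__main__":
    import yaml  # PyYAML, assumed available

    with open("config.yaml") as f:
        config = yaml.safe_load(f)

    client_kwargs, guide_kwargs = process_config(config)
    print("client:", client_kwargs)
    print("guide:", guide_kwargs)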