File size: 3,419 Bytes
312035b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import os
# Default client
# URL template for the Hugging Face serverless Inference API; {model_id} is
# filled in with the model repo id.
INFERENCE_SERVER_URL = "https://api-inference.huggingface.co/models/{model_id}"
# Default client-facing chat model.
MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
# Generation defaults applied to the client LLM (see process_config below).
CLIENT_MODEL_KWARGS = {
    "max_tokens": 800,
    "temperature": 0.6,
}
# Defaults for the guide/expert side: expert LLM endpoint plus an NLI
# classifier used alongside it. The commented-out values are an alternative
# (larger) expert model left for reference.
GUIDE_KWARGS = {
    "expert_model": "HuggingFaceH4/zephyr-7b-beta",
    # "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "inference_server_url": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta",
    # "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-70B-Instruct",
    "llm_backend": "HFChat",
    "classifier_kwargs": {
        # Zero-shot NLI classifier used by the guide.
        "model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
        "inference_server_url": "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
        "batch_size": 8,
    },
}
def process_config(config):
    """Validate a parsed config.yaml mapping and build kwargs dicts.

    Expects ``config`` to contain three sections — ``client_llm``,
    ``expert_llm`` and ``classifier_llm`` — each with a ``model_id`` and
    ``url`` (``classifier_llm`` additionally needs ``batch_size``), and the
    ``HF_TOKEN`` environment variable to be set.

    Returns:
        tuple[dict, dict]: ``(client_kwargs, guide_kwargs)`` where
        ``guide_kwargs`` nests the classifier settings under
        ``"classifier_kwargs"``.

    Raises:
        ValueError: if ``HF_TOKEN`` is unset or any required section/key
        is missing from the config.
    """
    if "HF_TOKEN" not in os.environ:
        raise ValueError("Please set the HF_TOKEN environment variable.")
    api_key = os.getenv("HF_TOKEN")

    def _section(name):
        # Fetch a top-level config section or fail with the canonical message.
        if name not in config:
            raise ValueError(f"config.yaml is missing {name} settings.")
        return config[name]

    def _value(section, key, label):
        # Fetch a required key from a section or fail with the canonical message.
        if key not in section:
            raise ValueError(f"config.yaml is missing {label} {key}.")
        return section[key]

    client_llm = _section("client_llm")
    client_kwargs = {
        "model_id": _value(client_llm, "model_id", "client"),
        "inference_server_url": _value(client_llm, "url", "client"),
        "api_key": api_key,
        "llm_backend": "HFChat",
        # Generation defaults come from the module-level constants.
        "temperature": CLIENT_MODEL_KWARGS["temperature"],
        "max_tokens": CLIENT_MODEL_KWARGS["max_tokens"],
    }

    expert_llm = _section("expert_llm")
    guide_kwargs = {
        "expert_model": _value(expert_llm, "model_id", "expert"),
        "inference_server_url": _value(expert_llm, "url", "expert"),
        "api_key": api_key,
        "llm_backend": "HFChat",
    }

    classifier_llm = _section("classifier_llm")
    # BUGFIX: the nested dict must be created before it is populated — the
    # original assigned into guide_kwargs["classifier_kwargs"][...] without
    # ever creating it, raising KeyError on every valid config.
    guide_kwargs["classifier_kwargs"] = {
        "model_id": _value(classifier_llm, "model_id", "classifier"),
        "inference_server_url": _value(classifier_llm, "url", "classifier"),
        "batch_size": _value(classifier_llm, "batch_size", "classifier"),
        "api_key": api_key,  # classifier api key
    }

    return client_kwargs, guide_kwargs
|