|
import os |
|
|
|
|
|
# Template for per-model Hugging Face Inference API endpoints.
INFERENCE_SERVER_URL = "https://api-inference.huggingface.co/models/{model_id}"

# Default chat model served to the client.
MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"

# Default NLI classifier used by the guide.
_CLASSIFIER_MODEL_ID = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"

# Sampling defaults for the client-side LLM.
CLIENT_MODEL_KWARGS = {
    "max_tokens": 800,
    "temperature": 0.6,
}

# Default settings for the expert/guide LLM and its entailment classifier.
# URLs are derived from INFERENCE_SERVER_URL so model id and endpoint
# cannot drift apart.
GUIDE_KWARGS = {
    "expert_model": MODEL_ID,
    "inference_server_url": INFERENCE_SERVER_URL.format(model_id=MODEL_ID),
    "llm_backend": "HFChat",
    "classifier_kwargs": {
        "model_id": _CLASSIFIER_MODEL_ID,
        "inference_server_url": INFERENCE_SERVER_URL.format(
            model_id=_CLASSIFIER_MODEL_ID
        ),
        "batch_size": 8,
    },
}
|
|
|
|
|
def _require(mapping, key, error_message):
    """Return ``mapping[key]``, raising ValueError(error_message) if absent."""
    if key not in mapping:
        raise ValueError(error_message)
    return mapping[key]


def process_config(config):
    """Build client and guide keyword arguments from a parsed config mapping.

    Args:
        config: Mapping parsed from ``config.yaml``; must contain the
            ``client_llm``, ``expert_llm``, and ``classifier_llm`` sections,
            each with a ``model_id`` and ``url`` (``classifier_llm`` also
            needs ``batch_size``).

    Returns:
        Tuple ``(client_kwargs, guide_kwargs)`` of keyword-argument dicts
        for constructing the client LLM and the guide (expert + classifier).

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is unset, or
            any required section/key is missing from ``config``.
    """
    if "HF_TOKEN" not in os.environ:
        raise ValueError("Please set the HF_TOKEN environment variable.")
    api_key = os.getenv("HF_TOKEN")

    client_llm = _require(
        config, "client_llm", "config.yaml is missing client_llm settings."
    )
    client_kwargs = {
        "model_id": _require(
            client_llm, "model_id", "config.yaml is missing client model_id."
        ),
        "inference_server_url": _require(
            client_llm, "url", "config.yaml is missing client url."
        ),
        "api_key": api_key,
        "llm_backend": "HFChat",
        "temperature": CLIENT_MODEL_KWARGS["temperature"],
        "max_tokens": CLIENT_MODEL_KWARGS["max_tokens"],
    }

    expert_llm = _require(
        config, "expert_llm", "config.yaml is missing expert_llm settings."
    )
    guide_kwargs = {
        "expert_model": _require(
            expert_llm, "model_id", "config.yaml is missing expert model_id."
        ),
        "inference_server_url": _require(
            expert_llm, "url", "config.yaml is missing expert url."
        ),
        "api_key": api_key,
        "llm_backend": "HFChat",
    }

    classifier_llm = _require(
        config, "classifier_llm", "config.yaml is missing classifier_llm settings."
    )
    # BUG FIX: the original assigned into guide_kwargs["classifier_kwargs"]
    # without ever creating that nested dict, so a config containing a
    # classifier_llm section crashed with KeyError("classifier_kwargs").
    guide_kwargs["classifier_kwargs"] = {
        "model_id": _require(
            classifier_llm,
            "model_id",
            "config.yaml is missing classifier model_id.",
        ),
        "inference_server_url": _require(
            classifier_llm, "url", "config.yaml is missing classifier url."
        ),
        "batch_size": _require(
            classifier_llm,
            "batch_size",
            "config.yaml is missing classifier batch_size.",
        ),
        "api_key": api_key,
    }

    return client_kwargs, guide_kwargs
|
|
|
|