File size: 2,688 Bytes
cd47483
 
 
 
 
5ac0c97
 
 
 
cd47483
 
ab34078
cd47483
 
 
 
 
1bff30e
 
 
cd47483
4106f96
 
1bff30e
 
 
 
 
 
 
cd47483
 
4106f96
 
62bb2f6
 
a0cefd0
 
 
 
 
 
 
 
3b90025
85b97c4
 
 
62bb2f6
 
85b97c4
 
 
62bb2f6
 
4106f96
 
cd47483
4106f96
 
cd47483
4106f96
cd47483
 
 
5ac0c97
 
 
cd47483
 
 
 
 
 
 
ab34078
 
cd47483
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import warnings

import argilla as rg

# Tasks
# String identifiers for the two task kinds this module names.
TEXTCAT_TASK: str = "text_classification"
SFT_TASK: str = "supervised_fine_tuning"

# Hugging Face
# The token is mandatory: importing this module fails immediately when the
# HF_TOKEN environment variable is missing or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment "
        "variable that has access to the Hugging Face Hub repositories and "
        "Inference Endpoints."
    )

# Inference
MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
MAX_NUM_ROWS: str | int = int(os.getenv("MAX_NUM_ROWS", 1000))
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
BASE_URL = os.getenv("BASE_URL", default=None)

_API_KEY = os.getenv("API_KEY")
if _API_KEY:
    API_KEYS = [_API_KEY]
else:
    API_KEYS = [os.getenv("HF_TOKEN")] + [
        os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)
    ]
API_KEYS = [token for token in API_KEYS if token]

# Determine if SFT is available
# The pre-query template comes from the MAGPIE_PRE_QUERY_TEMPLATE environment
# variable when it is non-empty; otherwise it is derived from the model name.
# A non-empty BASE_URL switches SFT off regardless of either.
llama_options = ["llama3", "llama-3", "llama 3"]
qwen_options = ["qwen2", "qwen-2", "qwen 2"]

if os.getenv("MAGPIE_PRE_QUERY_TEMPLATE"):
    passed_pre_query_template = os.getenv("MAGPIE_PRE_QUERY_TEMPLATE")
    # Known aliases collapse to a canonical name; any other value is kept
    # exactly as it was passed.
    MAGPIE_PRE_QUERY_TEMPLATE = (
        "llama3"
        if passed_pre_query_template.lower() in llama_options
        else "qwen2"
        if passed_pre_query_template.lower() in qwen_options
        else passed_pre_query_template
    )
    SFT_AVAILABLE = True
elif any(option in MODEL.lower() for option in llama_options):
    # The substring test also covers a model name equal to one of the options.
    MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
    SFT_AVAILABLE = True
elif any(option in MODEL.lower() for option in qwen_options):
    MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
    SFT_AVAILABLE = True
else:
    SFT_AVAILABLE = False

if BASE_URL or not SFT_AVAILABLE:
    SFT_AVAILABLE = False
    warnings.warn(
        message="`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints to generate chat data."
    )
    MAGPIE_PRE_QUERY_TEMPLATE = None

# Embeddings
# Fixed model identifier; unlike the settings above it is not read from the
# environment.
STATIC_EMBEDDING_MODEL: str = "minishlab/potion-base-8M"

# Argilla
# Connection settings come from ARGILLA_API_URL / ARGILLA_API_KEY. When either
# one is unset or blank, both are replaced by the *_SDG_REVIEWER pair so the
# URL and key always come from the same source.
ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
if not ARGILLA_API_URL or not ARGILLA_API_KEY:
    # Blank values count as missing, matching the emptiness check below;
    # an `is None` test would let an empty primary variable block the fallback.
    ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
    ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")

# Without a usable URL and key no client is built; `argilla_client` is None
# in that case and a warning is emitted.
if not ARGILLA_API_URL or not ARGILLA_API_KEY:
    warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")
    argilla_client = None
else:
    argilla_client = rg.Argilla(
        api_url=ARGILLA_API_URL,
        api_key=ARGILLA_API_KEY,
    )