davidberenstein1957 HF staff commited on
Commit
71fd9c5
·
1 Parent(s): fedb936

feat: add token roulation logic

Browse files
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -1,11 +1,11 @@
1
- import os
2
-
3
  import pandas as pd
4
  from distilabel.llms import InferenceEndpointsLLM
5
  from distilabel.pipeline import Pipeline
6
  from distilabel.steps import KeepColumns
7
  from distilabel.steps.tasks import MagpieGenerator, TextGeneration
8
 
 
 
9
  INFORMATION_SEEKING_PROMPT = (
10
  "You are an AI assistant designed to provide accurate and concise information on a wide"
11
  " range of topics. Your purpose is to assist users in finding specific facts,"
@@ -139,6 +139,7 @@ _STOP_SEQUENCES = [
139
  " \n\n",
140
  ]
141
  DEFAULT_BATCH_SIZE = 1
 
142
 
143
 
144
  def _get_output_mappings(num_turns):
@@ -189,15 +190,18 @@ if __name__ == "__main__":
189
 
190
 
191
  def get_pipeline(num_turns, num_rows, system_prompt):
 
192
  input_mappings = _get_output_mappings(num_turns)
193
  output_mappings = input_mappings
 
 
194
  if num_turns == 1:
195
  with Pipeline(name="sft") as pipeline:
196
  magpie = MagpieGenerator(
197
  llm=InferenceEndpointsLLM(
198
  model_id=MODEL,
199
  tokenizer_id=MODEL,
200
- api_key=os.environ["HF_TOKEN"],
201
  magpie_pre_query_template="llama3",
202
  generation_kwargs={
203
  "temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
@@ -218,7 +222,7 @@ def get_pipeline(num_turns, num_rows, system_prompt):
218
  llm=InferenceEndpointsLLM(
219
  model_id=MODEL,
220
  tokenizer_id=MODEL,
221
- api_key=os.environ["HF_TOKEN"],
222
  generation_kwargs={"temperature": 0.8, "max_new_tokens": 1024},
223
  ),
224
  system_prompt=system_prompt,
@@ -239,7 +243,7 @@ def get_pipeline(num_turns, num_rows, system_prompt):
239
  llm=InferenceEndpointsLLM(
240
  model_id=MODEL,
241
  tokenizer_id=MODEL,
242
- api_key=os.environ["HF_TOKEN"],
243
  magpie_pre_query_template="llama3",
244
  generation_kwargs={
245
  "temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
@@ -262,9 +266,12 @@ def get_pipeline(num_turns, num_rows, system_prompt):
262
 
263
 
264
  def get_prompt_generation_step():
 
 
 
265
  generate_description = TextGeneration(
266
  llm=InferenceEndpointsLLM(
267
- api_key=os.environ["HF_TOKEN"],
268
  model_id=MODEL,
269
  tokenizer_id=MODEL,
270
  generation_kwargs={
 
 
 
1
  import pandas as pd
2
  from distilabel.llms import InferenceEndpointsLLM
3
  from distilabel.pipeline import Pipeline
4
  from distilabel.steps import KeepColumns
5
  from distilabel.steps.tasks import MagpieGenerator, TextGeneration
6
 
7
+ from src.distilabel_dataset_generator.utils import HF_TOKENS
8
+
9
  INFORMATION_SEEKING_PROMPT = (
10
  "You are an AI assistant designed to provide accurate and concise information on a wide"
11
  " range of topics. Your purpose is to assist users in finding specific facts,"
 
139
  " \n\n",
140
  ]
141
  DEFAULT_BATCH_SIZE = 1
142
+ TOKEN_INDEX = 0
143
 
144
 
145
  def _get_output_mappings(num_turns):
 
190
 
191
 
192
  def get_pipeline(num_turns, num_rows, system_prompt):
193
+ global TOKEN_INDEX
194
  input_mappings = _get_output_mappings(num_turns)
195
  output_mappings = input_mappings
196
+ api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
197
+ TOKEN_INDEX += 1
198
  if num_turns == 1:
199
  with Pipeline(name="sft") as pipeline:
200
  magpie = MagpieGenerator(
201
  llm=InferenceEndpointsLLM(
202
  model_id=MODEL,
203
  tokenizer_id=MODEL,
204
+ api_key=api_key,
205
  magpie_pre_query_template="llama3",
206
  generation_kwargs={
207
  "temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
 
222
  llm=InferenceEndpointsLLM(
223
  model_id=MODEL,
224
  tokenizer_id=MODEL,
225
+ api_key=api_key,
226
  generation_kwargs={"temperature": 0.8, "max_new_tokens": 1024},
227
  ),
228
  system_prompt=system_prompt,
 
243
  llm=InferenceEndpointsLLM(
244
  model_id=MODEL,
245
  tokenizer_id=MODEL,
246
+ api_key=api_key,
247
  magpie_pre_query_template="llama3",
248
  generation_kwargs={
249
  "temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
 
266
 
267
 
268
  def get_prompt_generation_step():
269
+ global TOKEN_INDEX
270
+ api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
271
+ TOKEN_INDEX += 1
272
  generate_description = TextGeneration(
273
  llm=InferenceEndpointsLLM(
274
+ api_key=api_key,
275
  model_id=MODEL,
276
  tokenizer_id=MODEL,
277
  generation_kwargs={
src/distilabel_dataset_generator/utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
  from gradio.oauth import (
3
  OAUTH_CLIENT_ID,
@@ -8,6 +10,9 @@ from gradio.oauth import (
8
  )
9
  from huggingface_hub import whoami
10
 
 
 
 
11
  _CHECK_IF_SPACE_IS_SET = (
12
  all(
13
  [
 
1
+ import os
2
+
3
  import gradio as gr
4
  from gradio.oauth import (
5
  OAUTH_CLIENT_ID,
 
10
  )
11
  from huggingface_hub import whoami
12
 
13
+ HF_TOKENS = os.getenv("HF_TOKEN") + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
14
+ HF_TOKENS = [token for token in HF_TOKENS if token]
15
+
16
  _CHECK_IF_SPACE_IS_SET = (
17
  all(
18
  [