arthrod committed
Commit 253193c · 2 Parent(s): 3498f3f 5302b49

Merge remote-tracking branch 'origin/main'

.gitmodules ADDED
@@ -0,0 +1,2 @@
+ [submodule "vllm"]
+     update = checkout
Dockerfile ADDED
@@ -0,0 +1,55 @@
+ FROM docker.io/library/python:3.10@sha256:76f22e4ce53774c1f5eb0ba145edb57b908e7aa329fee75eca69b511c1d0cd8a
+
+ WORKDIR /home/user/app
+
+ # Install uv and create virtual environment
+ RUN python -m pip install uv && \
+     uv venv --python=3.12 .venv && \
+     . .venv/bin/activate
+
+ # Base pip updates and initial packages
+ RUN pip install --no-cache-dir pip -U && \
+     pip install --no-cache-dir datasets "huggingface-hub>=0.19" "hf-transfer>=0.1.4" "protobuf<4" "click<8.1" "pydantic~=1.0"
+
+ # System packages installation
+ RUN --mount=target=/tmp/packages.txt,source=packages.txt \
+     apt-get update && \
+     xargs -r -a /tmp/packages.txt apt-get install -y && \
+     apt-get install -y curl && \
+     curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
+     apt-get install -y nodejs && \
+     rm -rf /var/lib/apt/lists/* && \
+     apt-get clean
+
+ # Fakeroot setup
+ RUN apt-get update && \
+     apt-get install -y fakeroot && \
+     mv /usr/bin/apt-get /usr/bin/.apt-get && \
+     echo '#!/usr/bin/env sh\nfakeroot /usr/bin/.apt-get $@' > /usr/bin/apt-get && \
+     chmod +x /usr/bin/apt-get && \
+     rm -rf /var/lib/apt/lists/* && \
+     useradd -m -u 1000 user
+
+ # Copy files and install requirements
+ COPY --chown=1000:1000 --from=root / /
+
+ # Install requirements in specific order
+ RUN . .venv/bin/activate && \
+     pip install -r .venv/requirements-cpu.txt && \
+     pip install -e .venv/. && \
+     pip install -e .
+
+ RUN pip freeze > /tmp/freeze.txt
+
+ # Additional system packages
+ RUN apt-get update && \
+     apt-get install -y git git-lfs ffmpeg libsm6 libxext6 cmake rsync libgl1-mesa-glx && \
+     rm -rf /var/lib/apt/lists/* && \
+     git lfs install
+
+ # Install Gradio and related packages
+ RUN . .venv/bin/activate && \
+     pip install --no-cache-dir gradio[oauth]==5.8.0 "uvicorn>=0.14.0" spaces
+
+ COPY --link --chown=1000 ./ /home/user/app
+ COPY --from=pipfreeze --link --chown=1000 /tmp/freeze.txt /tmp/freeze.txt
a.py ADDED
@@ -0,0 +1,51 @@
+ import os
+ import re
+
+ def modify_synthetic_imports(file_path):
+     """Modify imports of synthetic_dataset_generator to add src."""
+     try:
+         with open(file_path, 'r') as file:
+             content = file.read()
+
+         # Replace both import patterns to add src.
+         modified_content = re.sub(
+             r'from synthetic_dataset_generator',
+             'from src.synthetic_dataset_generator',
+             content
+         )
+         modified_content = re.sub(
+             r'import synthetic_dataset_generator',
+             'import src.synthetic_dataset_generator',
+             modified_content
+         )
+
+         # Only write if changes were made
+         if modified_content != content:
+             with open(file_path, 'w') as file:
+                 file.write(modified_content)
+             print(f"Modified imports in: {file_path}")
+
+     except Exception as e:
+         print(f"Error processing {file_path}: {str(e)}")
+
+ def process_directory(start_path):
+     """Recursively process all Python files in directory"""
+     for root, _, files in os.walk(start_path):
+         for file in files:
+             if file.endswith('.py'):
+                 file_path = os.path.join(root, file)
+                 modify_synthetic_imports(file_path)
+
+ if __name__ == "__main__":
+     import sys
+     if len(sys.argv) != 2:
+         print("Usage: python script.py <directory_path>")
+         sys.exit(1)
+
+     directory_path = sys.argv[1]
+     if not os.path.isdir(directory_path):
+         print(f"Error: {directory_path} is not a valid directory")
+         sys.exit(1)
+
+     process_directory(directory_path)
+     print("Processing complete!")
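A minimal sanity check of the rewrite logic above, with hypothetical sample strings (not from the repo). The pattern anchors on the bare module name, so already-prefixed imports are left alone:

    import re

    sample = (
        "from synthetic_dataset_generator import launch\n"
        "from src.synthetic_dataset_generator.app import demo\n"
    )
    # the bare import gains the src. prefix; the already-prefixed line is
    # untouched, because "src." sits between "from " and the module name
    rewritten = re.sub(
        r"from synthetic_dataset_generator",
        "from src.synthetic_dataset_generator",
        sample,
    )
    assert rewritten.count("from src.synthetic_dataset_generator") == 2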
app.py CHANGED
@@ -1,3 +1,3 @@
- from synthetic_dataset_generator import launch
+ from src.synthetic_dataset_generator import launch
 
  launch()
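Every import change in this commit follows the same pattern: the package is now addressed as src.synthetic_dataset_generator, which only resolves when the repository root is on sys.path (e.g. when the app is started from the checkout directory, as the Dockerfile's WORKDIR arranges). A minimal sketch of that assumption, not part of the commit:

    import sys

    # hypothetical launcher: run from the repository root so that src/
    # is importable as a top-level package
    sys.path.insert(0, ".")
    from src.synthetic_dataset_generator import launch

    launch()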
examples/argilla-deployment.py CHANGED
@@ -6,7 +6,7 @@
  # ///
  import os
 
- from synthetic_dataset_generator import launch
+ from src.synthetic_dataset_generator import launch
 
  # Follow https://docs.argilla.io/latest/getting_started/quickstart/ to get your Argilla API key and URL
  os.environ["HF_TOKEN"] = "hf_..."
examples/hf-dedicated-or-tgi-deployment.py CHANGED
@@ -6,7 +6,7 @@
  # ///
  import os
 
- from synthetic_dataset_generator import launch
+ from src.synthetic_dataset_generator import launch
 
  os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
  os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/"  # dedicated endpoint/TGI
examples/hf-serverless-deployment.py CHANGED
@@ -6,7 +6,7 @@
  # ///
  import os
 
- from synthetic_dataset_generator import launch
+ from src.synthetic_dataset_generator import launch
 
  os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
  os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"  # use model for generation
examples/ollama-deployment.py CHANGED
@@ -8,7 +8,7 @@
  # ollama run qwen2.5:32b-instruct-q5_K_S
  import os
 
- from synthetic_dataset_generator import launch
+ from src.synthetic_dataset_generator import launch
 
  os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
  os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"  # ollama base url
examples/openai-deployment.py CHANGED
@@ -7,7 +7,7 @@
 
  import os
 
- from synthetic_dataset_generator import launch
+ from src.synthetic_dataset_generator import launch
 
  os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
  os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/"  # openai base url
examples/vllm-deployment.py CHANGED
@@ -7,7 +7,7 @@
  # vllm serve Qwen/Qwen2.5-1.5B-Instruct
  import os
 
- from synthetic_dataset_generator import launch
+ from src.synthetic_dataset_generator import launch
 
  os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
  os.environ["VLLM_BASE_URL"] = "http://127.0.0.1:8000/"  # vllm base url
src/synthetic_dataset_generator/__init__.py CHANGED
@@ -2,7 +2,7 @@ import inspect
 
  from gradio import TabbedInterface
 
- from synthetic_dataset_generator import (  # noqa
+ from src.synthetic_dataset_generator import (  # noqa
      _distiset,
      _inference_endpoints,
  )
@@ -13,7 +13,7 @@ def launch(*args, **kwargs):
      Based on the `TabbedInterface` from Gradio.
      Parameters: https://www.gradio.app/docs/gradio/tabbedinterface
      """
-     from synthetic_dataset_generator.app import demo
+     from src.synthetic_dataset_generator.app import demo
 
      return demo.launch(*args, **kwargs)
 
src/synthetic_dataset_generator/__main__.py CHANGED
@@ -1,4 +1,4 @@
  if __name__ == "__main__":
-     from synthetic_dataset_generator import launch
+     from src.synthetic_dataset_generator import launch
 
      launch()
src/synthetic_dataset_generator/app.py CHANGED
@@ -1,10 +1,10 @@
- from synthetic_dataset_generator._tabbedinterface import TabbedInterface
+ from src.synthetic_dataset_generator._tabbedinterface import TabbedInterface
 
- # from synthetic_dataset_generator.apps.eval import app as eval_app
- from synthetic_dataset_generator.apps.rag import app as rag_app
- from synthetic_dataset_generator.apps.about import app as about_app
- from synthetic_dataset_generator.apps.chat import app as chat_app
- from synthetic_dataset_generator.apps.textcat import app as textcat_app
+ # from src.synthetic_dataset_generator.apps.eval import app as eval_app
+ from src.synthetic_dataset_generator.apps.rag import app as rag_app
+ from src.synthetic_dataset_generator.apps.about import app as about_app
+ from src.synthetic_dataset_generator.apps.chat import app as chat_app
+ from src.synthetic_dataset_generator.apps.textcat import app as textcat_app
 
  theme = "argilla/argilla-theme"
 
src/synthetic_dataset_generator/apps/base.py CHANGED
@@ -12,8 +12,8 @@ from huggingface_hub import HfApi, upload_file, repo_exists
  from unstructured.chunking.title import chunk_by_title
  from unstructured.partition.auto import partition
 
- from synthetic_dataset_generator.constants import MAX_NUM_ROWS
- from synthetic_dataset_generator.utils import get_argilla_client
+ from src.synthetic_dataset_generator.constants import MAX_NUM_ROWS
+ from src.synthetic_dataset_generator.utils import get_argilla_client
 
 
  def validate_argilla_user_workspace_dataset(
src/synthetic_dataset_generator/apps/chat.py CHANGED
@@ -13,7 +13,7 @@ from gradio.oauth import OAuthToken
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
  from huggingface_hub import HfApi
 
- from synthetic_dataset_generator.apps.base import (
+ from src.synthetic_dataset_generator.apps.base import (
      combine_datasets,
      hide_success_message,
      load_dataset_from_hub,
@@ -24,15 +24,15 @@ from synthetic_dataset_generator.apps.base import (
      validate_argilla_user_workspace_dataset,
      validate_push_to_hub,
  )
- from synthetic_dataset_generator.constants import (
+ from src.synthetic_dataset_generator.constants import (
      BASE_URL,
      DEFAULT_BATCH_SIZE,
      MODEL,
      MODEL_COMPLETION,
      SFT_AVAILABLE,
  )
- from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
- from synthetic_dataset_generator.pipelines.chat import (
+ from src.synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
+ from src.synthetic_dataset_generator.pipelines.chat import (
      DEFAULT_DATASET_DESCRIPTIONS,
      generate_pipeline_code,
      get_follow_up_generator,
@@ -41,12 +41,11 @@ from synthetic_dataset_generator.pipelines.chat import (
      get_response_generator,
      get_sentence_pair_generator,
  )
- from synthetic_dataset_generator.pipelines.embeddings import (
+ from src.synthetic_dataset_generator.pipelines.embeddings import (
      get_embeddings,
      get_sentence_embedding_dimensions,
  )
- from synthetic_dataset_generator.utils import (
-     column_to_list,
+ from src.synthetic_dataset_generator.utils import (
      get_argilla_client,
      get_org_dropdown,
      get_random_repo_name,
src/synthetic_dataset_generator/apps/eval.py CHANGED
@@ -17,7 +17,7 @@ from gradio.oauth import OAuthToken
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
  from huggingface_hub import HfApi
 
- from synthetic_dataset_generator.apps.base import (
+ from src.synthetic_dataset_generator.apps.base import (
      combine_datasets,
      get_iframe,
      hide_success_message,
@@ -27,17 +27,17 @@ from synthetic_dataset_generator.apps.base import (
      validate_argilla_user_workspace_dataset,
      validate_push_to_hub,
  )
- from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
- from synthetic_dataset_generator.pipelines.embeddings import (
+ from src.synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+ from src.synthetic_dataset_generator.pipelines.embeddings import (
      get_embeddings,
      get_sentence_embedding_dimensions,
  )
- from synthetic_dataset_generator.pipelines.eval import (
+ from src.synthetic_dataset_generator.pipelines.eval import (
      generate_pipeline_code,
      get_custom_evaluator,
      get_ultrafeedback_evaluator,
  )
- from synthetic_dataset_generator.utils import (
+ from src.synthetic_dataset_generator.utils import (
      column_to_list,
      extract_column_names,
      get_argilla_client,
src/synthetic_dataset_generator/apps/rag.py CHANGED
@@ -13,7 +13,7 @@ from gradio.oauth import OAuthToken
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
  from huggingface_hub import HfApi
 
- from synthetic_dataset_generator.apps.base import (
+ from src.synthetic_dataset_generator.apps.base import (
      combine_datasets,
      hide_success_message,
      load_dataset_from_hub,
@@ -24,13 +24,13 @@ from synthetic_dataset_generator.apps.base import (
      validate_argilla_user_workspace_dataset,
      validate_push_to_hub,
  )
- from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, MODEL_COMPLETION
- from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
- from synthetic_dataset_generator.pipelines.embeddings import (
+ from src.synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+ from src.synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
+ from src.synthetic_dataset_generator.pipelines.embeddings import (
      get_embeddings,
      get_sentence_embedding_dimensions,
  )
- from synthetic_dataset_generator.pipelines.rag import (
+ from src.synthetic_dataset_generator.pipelines.rag import (
      DEFAULT_DATASET_DESCRIPTIONS,
      generate_pipeline_code,
      get_chunks_generator,
@@ -38,7 +38,7 @@ from synthetic_dataset_generator.pipelines.rag import (
      get_response_generator,
      get_sentence_pair_generator,
  )
- from synthetic_dataset_generator.utils import (
+ from src.synthetic_dataset_generator.utils import (
      column_to_list,
      get_argilla_client,
      get_org_dropdown,
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -10,7 +10,7 @@ from datasets import ClassLabel, Dataset, Features, Sequence, Value
  from distilabel.distiset import Distiset
  from huggingface_hub import HfApi
 
- from synthetic_dataset_generator.apps.base import (
+ from src.synthetic_dataset_generator.apps.base import (
      combine_datasets,
      hide_success_message,
      push_pipeline_code_to_hub,
@@ -19,20 +19,20 @@ from synthetic_dataset_generator.apps.base import (
      validate_argilla_user_workspace_dataset,
      validate_push_to_hub,
  )
- from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
- from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
- from synthetic_dataset_generator.pipelines.embeddings import (
+ from src.synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+ from src.synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
+ from src.synthetic_dataset_generator.pipelines.embeddings import (
      get_embeddings,
      get_sentence_embedding_dimensions,
  )
- from synthetic_dataset_generator.pipelines.textcat import (
+ from src.synthetic_dataset_generator.pipelines.textcat import (
      DEFAULT_DATASET_DESCRIPTIONS,
      generate_pipeline_code,
      get_labeller_generator,
      get_prompt_generator,
      get_textcat_generator,
  )
- from synthetic_dataset_generator.utils import (
+ from src.synthetic_dataset_generator.utils import (
      get_argilla_client,
      get_org_dropdown,
      get_preprocess_labels,
src/synthetic_dataset_generator/pipelines/base.py CHANGED
@@ -4,7 +4,7 @@ import random
  from distilabel.models import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
  from distilabel.steps.tasks import TextGeneration
 
- from synthetic_dataset_generator.constants import (
+ from src.synthetic_dataset_generator.constants import (
      API_KEYS,
      DEFAULT_BATCH_SIZE,
      HUGGINGFACE_BASE_URL,
src/synthetic_dataset_generator/pipelines/chat.py CHANGED
@@ -6,11 +6,11 @@ from distilabel.steps.tasks import (
      TextGeneration,
  )
 
- from synthetic_dataset_generator.constants import (
+ from src.synthetic_dataset_generator.constants import (
      MAGPIE_PRE_QUERY_TEMPLATE,
      MAX_NUM_TOKENS,
  )
- from synthetic_dataset_generator.pipelines.base import _get_llm, _get_llm_class
+ from src.synthetic_dataset_generator.pipelines.base import _get_llm, _get_llm_class
 
  INFORMATION_SEEKING_PROMPT = (
      "You are an AI assistant designed to provide accurate and concise information on a wide"
src/synthetic_dataset_generator/pipelines/embeddings.py CHANGED
@@ -3,7 +3,7 @@ from typing import List
  from sentence_transformers import SentenceTransformer
  from sentence_transformers.models import StaticEmbedding
 
- from synthetic_dataset_generator.constants import STATIC_EMBEDDING_MODEL
+ from src.synthetic_dataset_generator.constants import STATIC_EMBEDDING_MODEL
 
  static_embedding = StaticEmbedding.from_model2vec(STATIC_EMBEDDING_MODEL)
  model = SentenceTransformer(modules=[static_embedding])
src/synthetic_dataset_generator/pipelines/eval.py CHANGED
@@ -7,9 +7,9 @@ from distilabel.steps.tasks import (
      UltraFeedback,
  )
 
- from synthetic_dataset_generator.constants import BASE_URL, MAX_NUM_TOKENS, MODEL
- from synthetic_dataset_generator.pipelines.base import _get_next_api_key
- from synthetic_dataset_generator.utils import extract_column_names
+ from src.synthetic_dataset_generator.constants import BASE_URL, MAX_NUM_TOKENS, MODEL
+ from src.synthetic_dataset_generator.pipelines.base import _get_next_api_key
+ from src.synthetic_dataset_generator.utils import extract_column_names
 
 
  def get_ultrafeedback_evaluator(aspect: str, is_sample: bool):
src/synthetic_dataset_generator/pipelines/rag.py CHANGED
@@ -4,8 +4,8 @@ from distilabel.steps.tasks import (
      TextGeneration,
  )
 
- from synthetic_dataset_generator.constants import MAX_NUM_TOKENS
- from synthetic_dataset_generator.pipelines.base import _get_llm, _get_llm_class
+ from src.synthetic_dataset_generator.constants import MAX_NUM_TOKENS
+ from src.synthetic_dataset_generator.pipelines.base import _get_llm, _get_llm_class
 
  DEFAULT_DATASET_DESCRIPTIONS = [
      "A dataset to retrieve information from legal documents.",
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -8,11 +8,11 @@ from distilabel.steps.tasks import (
  )
  from pydantic import BaseModel, Field
 
- from synthetic_dataset_generator.constants import (
+ from src.synthetic_dataset_generator.constants import (
      MAX_NUM_TOKENS,
  )
- from synthetic_dataset_generator.pipelines.base import _get_llm, _get_llm_class
- from synthetic_dataset_generator.utils import get_preprocess_labels
+ from src.synthetic_dataset_generator.pipelines.base import _get_llm, _get_llm_class
+ from src.synthetic_dataset_generator.utils import get_preprocess_labels
 
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
 
src/synthetic_dataset_generator/utils.py CHANGED
@@ -14,7 +14,7 @@ from gradio.oauth import (
  from huggingface_hub import whoami
  from jinja2 import Environment, meta
 
- from synthetic_dataset_generator.constants import argilla_client
+ from src.synthetic_dataset_generator.constants import argilla_client
 
 
  def get_duplicate_button():
uv.lock ADDED
The diff for this file is too large to render.