Commit
•
099e99c
1
Parent(s):
080f560
refactor: redesign of the generator
Browse files- app.py +23 -49
- demo.py +61 -0
- pdm.lock +0 -0
- pyproject.toml +4 -2
- requirements.txt +148 -7
- src/distilabel_dataset_generator/_tabbedinterface.py +73 -0
- src/distilabel_dataset_generator/apps/base.py +16 -28
- src/distilabel_dataset_generator/apps/eval.py +328 -0
- src/distilabel_dataset_generator/apps/sft.py +248 -296
- src/distilabel_dataset_generator/apps/textcat.py +291 -343
- src/distilabel_dataset_generator/pipelines/sft.py +2 -28
- src/distilabel_dataset_generator/pipelines/textcat.py +2 -28
- src/distilabel_dataset_generator/utils.py +11 -6
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
|
|
|
3 |
from src.distilabel_dataset_generator.apps.faq import app as faq_app
|
4 |
from src.distilabel_dataset_generator.apps.sft import app as sft_app
|
5 |
from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
|
@@ -23,64 +24,37 @@ css = """
|
|
23 |
background-color: black;
|
24 |
}
|
25 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
"""
|
27 |
|
28 |
-
demo =
|
29 |
[textcat_app, sft_app, faq_app],
|
30 |
["Text Classification", "Supervised Fine-Tuning", "FAQ"],
|
31 |
css=css,
|
32 |
title="""
|
33 |
-
<
|
34 |
-
|
35 |
-
display: flex;
|
36 |
-
align-items: center;
|
37 |
-
justify-content: center;
|
38 |
-
position: relative;
|
39 |
-
padding: 20px 0;
|
40 |
-
}
|
41 |
-
.logo-container {
|
42 |
-
position: absolute;
|
43 |
-
left: 0;
|
44 |
-
top: 0;
|
45 |
-
}
|
46 |
-
.title-container {
|
47 |
-
text-align: center;
|
48 |
-
}
|
49 |
-
@media (max-width: 600px) {
|
50 |
-
.header-container {
|
51 |
-
flex-direction: column;
|
52 |
-
}
|
53 |
-
.logo-container {
|
54 |
-
position: static;
|
55 |
-
margin-bottom: 20px;
|
56 |
-
}
|
57 |
-
}
|
58 |
-
button[role="tab"].selected,
|
59 |
-
button[role="tab"][aria-selected="true"],
|
60 |
-
button[role="tab"][data-tab-id][aria-selected="true"] {
|
61 |
-
background-color: #000000;
|
62 |
-
color: white;
|
63 |
-
border: none;
|
64 |
-
font-size: 16px;
|
65 |
-
font-weight: bold;
|
66 |
-
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
|
67 |
-
transition: background-color 0.3s ease, color 0.3s ease;
|
68 |
-
}
|
69 |
-
</style>
|
70 |
-
<div class="header-container">
|
71 |
-
<div class="logo-container">
|
72 |
-
<a href="https://github.com/argilla-io/distilabel" target="_blank" rel="noopener noreferrer">
|
73 |
-
<img src="https://distilabel.argilla.io/latest/assets/distilabel-black.svg" alt="Distilabel Logo" style="width: 150px; height: auto;">
|
74 |
-
</a>
|
75 |
-
</div>
|
76 |
-
<div class="title-container">
|
77 |
-
<h1 style="margin: 0; font-size: 2em;">🧬 Synthetic Data Generator</h1>
|
78 |
-
<p style="margin: 10px 0 0 0; color: #666; font-size: 1.1em;">Build datasets using natural language</p>
|
79 |
-
</div>
|
80 |
-
</div>
|
81 |
""",
|
|
|
82 |
theme=theme,
|
83 |
)
|
84 |
|
|
|
85 |
if __name__ == "__main__":
|
86 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
+
from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface
|
4 |
from src.distilabel_dataset_generator.apps.faq import app as faq_app
|
5 |
from src.distilabel_dataset_generator.apps.sft import app as sft_app
|
6 |
from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
|
|
|
24 |
background-color: black;
|
25 |
}
|
26 |
}
|
27 |
+
button[role="tab"].selected,
|
28 |
+
button[role="tab"][aria-selected="true"],
|
29 |
+
button[role="tab"][data-tab-id][aria-selected="true"] {
|
30 |
+
background-color: #000000;
|
31 |
+
color: white;
|
32 |
+
border: none;
|
33 |
+
font-size: 16px;
|
34 |
+
font-weight: bold;
|
35 |
+
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
|
36 |
+
transition: background-color 0.3s ease, color 0.3s ease;
|
37 |
+
}
|
38 |
+
.gallery {
|
39 |
+
color: black !important;
|
40 |
+
}
|
41 |
+
.flex-shrink-0.truncate.px-1 {
|
42 |
+
color: black !important;
|
43 |
+
}
|
44 |
"""
|
45 |
|
46 |
+
demo = TabbedInterface(
|
47 |
[textcat_app, sft_app, faq_app],
|
48 |
["Text Classification", "Supervised Fine-Tuning", "FAQ"],
|
49 |
css=css,
|
50 |
title="""
|
51 |
+
<h1>Synthetic Data Generator</h1>
|
52 |
+
<h3>Build datasets using natural language</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
""",
|
54 |
+
head="Synthetic Data Generator",
|
55 |
theme=theme,
|
56 |
)
|
57 |
|
58 |
+
|
59 |
if __name__ == "__main__":
|
60 |
demo.launch()
|
demo.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface
|
4 |
+
from src.distilabel_dataset_generator.apps.eval import app as eval_app
|
5 |
+
from src.distilabel_dataset_generator.apps.faq import app as faq_app
|
6 |
+
from src.distilabel_dataset_generator.apps.sft import app as sft_app
|
7 |
+
from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
|
8 |
+
|
9 |
+
theme = gr.themes.Monochrome(
|
10 |
+
spacing_size="md",
|
11 |
+
font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
|
12 |
+
)
|
13 |
+
|
14 |
+
css = """
|
15 |
+
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
|
16 |
+
.tabitem{border: 0px}
|
17 |
+
.group_padding{padding: .55em}
|
18 |
+
#space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
|
19 |
+
#system_prompt_examples {
|
20 |
+
color: black;
|
21 |
+
}
|
22 |
+
@media (prefers-color-scheme: dark) {
|
23 |
+
#system_prompt_examples {
|
24 |
+
color: white;
|
25 |
+
background-color: black;
|
26 |
+
}
|
27 |
+
}
|
28 |
+
button[role="tab"].selected,
|
29 |
+
button[role="tab"][aria-selected="true"],
|
30 |
+
button[role="tab"][data-tab-id][aria-selected="true"] {
|
31 |
+
background-color: #000000;
|
32 |
+
color: white;
|
33 |
+
border: none;
|
34 |
+
font-size: 16px;
|
35 |
+
font-weight: bold;
|
36 |
+
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
|
37 |
+
transition: background-color 0.3s ease, color 0.3s ease;
|
38 |
+
}
|
39 |
+
.gallery {
|
40 |
+
color: black !important;
|
41 |
+
}
|
42 |
+
.flex-shrink-0.truncate.px-1 {
|
43 |
+
color: black !important;
|
44 |
+
}
|
45 |
+
"""
|
46 |
+
|
47 |
+
demo = TabbedInterface(
|
48 |
+
[textcat_app, sft_app, eval_app, faq_app],
|
49 |
+
["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
|
50 |
+
css=css,
|
51 |
+
title="""
|
52 |
+
<h1>Synthetic Data Generator</h1>
|
53 |
+
<h3>Build datasets using natural language</h3>
|
54 |
+
""",
|
55 |
+
head="Synthetic Data Generator",
|
56 |
+
theme=theme,
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == "__main__":
|
61 |
+
demo.launch()
|
pdm.lock
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
pyproject.toml
CHANGED
@@ -6,11 +6,13 @@ authors = [
|
|
6 |
{name = "davidberenstein1957", email = "[email protected]"},
|
7 |
]
|
8 |
dependencies = [
|
9 |
-
"distilabel[hf-inference-endpoints,argilla]>=1.4.1",
|
10 |
-
"gradio[oauth]
|
11 |
"transformers>=4.44.2",
|
12 |
"sentence-transformers>=3.2.0",
|
13 |
"model2vec>=0.2.4",
|
|
|
|
|
14 |
]
|
15 |
requires-python = "<3.13,>=3.10"
|
16 |
readme = "README.md"
|
|
|
6 |
{name = "davidberenstein1957", email = "[email protected]"},
|
7 |
]
|
8 |
dependencies = [
|
9 |
+
"distilabel[hf-inference-endpoints,argilla,outlines]>=1.4.1",
|
10 |
+
"gradio[oauth]<5.0.0",
|
11 |
"transformers>=4.44.2",
|
12 |
"sentence-transformers>=3.2.0",
|
13 |
"model2vec>=0.2.4",
|
14 |
+
"gradio-huggingfacehub-search>=0.0.7",
|
15 |
+
"argilla>=2.4.0",
|
16 |
]
|
17 |
requires-python = "<3.13,>=3.10"
|
18 |
readme = "README.md"
|
requirements.txt
CHANGED
@@ -1,7 +1,148 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file is @generated by PDM.
|
2 |
+
# Please do not edit it manually.
|
3 |
+
|
4 |
+
aiofiles==23.2.1
|
5 |
+
aiohappyeyeballs==2.4.3
|
6 |
+
aiohttp==3.11.7
|
7 |
+
aiosignal==1.3.1
|
8 |
+
airportsdata==20241001
|
9 |
+
annotated-types==0.7.0
|
10 |
+
anyio==4.6.2.post1
|
11 |
+
argilla==2.4.0
|
12 |
+
asttokens==2.4.1
|
13 |
+
async-timeout==5.0.1; python_version < "3.11"
|
14 |
+
attrs==24.2.0
|
15 |
+
authlib==1.3.2
|
16 |
+
certifi==2024.8.30
|
17 |
+
cffi==1.17.1; platform_python_implementation != "PyPy"
|
18 |
+
charset-normalizer==3.4.0
|
19 |
+
click==8.1.7
|
20 |
+
cloudpickle==3.1.0
|
21 |
+
colorama==0.4.6; platform_system == "Windows" or sys_platform == "win32"
|
22 |
+
contourpy==1.3.1
|
23 |
+
cryptography==43.0.3
|
24 |
+
cycler==0.12.1
|
25 |
+
datasets==3.1.0
|
26 |
+
decorator==5.1.1
|
27 |
+
dill==0.3.8
|
28 |
+
diskcache==5.6.3
|
29 |
+
distilabel==1.4.1
|
30 |
+
distilabel[argilla,hf-inference-endpoints,outlines]==1.4.1
|
31 |
+
exceptiongroup==1.2.2; python_version < "3.11"
|
32 |
+
executing==2.1.0
|
33 |
+
fastapi==0.115.5
|
34 |
+
ffmpy==0.4.0
|
35 |
+
filelock==3.16.1
|
36 |
+
fonttools==4.55.0
|
37 |
+
frozenlist==1.5.0
|
38 |
+
fsspec==2024.9.0
|
39 |
+
fsspec[http]==2024.9.0
|
40 |
+
gradio==4.44.1
|
41 |
+
gradio-client==1.3.0
|
42 |
+
gradio-huggingfacehub-search==0.0.7
|
43 |
+
gradio[oauth]==4.44.1
|
44 |
+
h11==0.14.0
|
45 |
+
httpcore==1.0.7
|
46 |
+
httpx==0.27.2
|
47 |
+
huggingface-hub==0.26.2
|
48 |
+
idna==3.10
|
49 |
+
importlib-resources==6.4.5
|
50 |
+
interegular==0.3.3
|
51 |
+
ipython==8.29.0
|
52 |
+
itsdangerous==2.2.0
|
53 |
+
jedi==0.19.2
|
54 |
+
jinja2==3.1.4
|
55 |
+
joblib==1.4.2
|
56 |
+
jsonschema==4.23.0
|
57 |
+
jsonschema-specifications==2024.10.1
|
58 |
+
kiwisolver==1.4.7
|
59 |
+
lark==1.2.2
|
60 |
+
llvmlite==0.43.0
|
61 |
+
markdown-it-py==3.0.0
|
62 |
+
markupsafe==2.1.5
|
63 |
+
matplotlib==3.9.2
|
64 |
+
matplotlib-inline==0.1.7
|
65 |
+
mdurl==0.1.2
|
66 |
+
model2vec==0.3.3
|
67 |
+
mpmath==1.3.0; python_version >= "3.9"
|
68 |
+
multidict==6.1.0
|
69 |
+
multiprocess==0.70.16
|
70 |
+
nest-asyncio==1.6.0
|
71 |
+
networkx==3.4.2
|
72 |
+
numba==0.60.0
|
73 |
+
numpy==1.26.4
|
74 |
+
nvidia-cublas-cu12==12.4.5.8; platform_system == "Linux" and platform_machine == "x86_64"
|
75 |
+
nvidia-cuda-cupti-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64"
|
76 |
+
nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64"
|
77 |
+
nvidia-cuda-runtime-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64"
|
78 |
+
nvidia-cudnn-cu12==9.1.0.70; platform_system == "Linux" and platform_machine == "x86_64"
|
79 |
+
nvidia-cufft-cu12==11.2.1.3; platform_system == "Linux" and platform_machine == "x86_64"
|
80 |
+
nvidia-curand-cu12==10.3.5.147; platform_system == "Linux" and platform_machine == "x86_64"
|
81 |
+
nvidia-cusolver-cu12==11.6.1.9; platform_system == "Linux" and platform_machine == "x86_64"
|
82 |
+
nvidia-cusparse-cu12==12.3.1.170; platform_system == "Linux" and platform_machine == "x86_64"
|
83 |
+
nvidia-nccl-cu12==2.21.5; platform_system == "Linux" and platform_machine == "x86_64"
|
84 |
+
nvidia-nvjitlink-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64"
|
85 |
+
nvidia-nvtx-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64"
|
86 |
+
orjson==3.10.11
|
87 |
+
outlines==0.1.4
|
88 |
+
outlines-core==0.1.17
|
89 |
+
packaging==24.2
|
90 |
+
pandas==2.2.3
|
91 |
+
parso==0.8.4
|
92 |
+
pexpect==4.9.0; sys_platform != "win32" and sys_platform != "emscripten"
|
93 |
+
pillow==10.4.0
|
94 |
+
portalocker==3.0.0
|
95 |
+
prompt-toolkit==3.0.48
|
96 |
+
propcache==0.2.0
|
97 |
+
ptyprocess==0.7.0; sys_platform != "win32" and sys_platform != "emscripten"
|
98 |
+
pure-eval==0.2.3
|
99 |
+
pyarrow==18.0.0
|
100 |
+
pycountry==24.6.1
|
101 |
+
pycparser==2.22; platform_python_implementation != "PyPy"
|
102 |
+
pydantic==2.10.0
|
103 |
+
pydantic-core==2.27.0
|
104 |
+
pydub==0.25.1
|
105 |
+
pygments==2.18.0
|
106 |
+
pyparsing==3.2.0
|
107 |
+
python-dateutil==2.9.0.post0
|
108 |
+
python-multipart==0.0.17
|
109 |
+
pytz==2024.2
|
110 |
+
pywin32==308; platform_system == "Windows"
|
111 |
+
pyyaml==6.0.2
|
112 |
+
referencing==0.35.1
|
113 |
+
regex==2024.11.6
|
114 |
+
requests==2.32.3
|
115 |
+
rich==13.9.4
|
116 |
+
rpds-py==0.21.0
|
117 |
+
ruff==0.7.4; sys_platform != "emscripten"
|
118 |
+
safetensors==0.4.5
|
119 |
+
scikit-learn==1.5.2
|
120 |
+
scipy==1.14.1
|
121 |
+
semantic-version==2.10.0
|
122 |
+
sentence-transformers==3.3.1
|
123 |
+
setuptools==75.6.0
|
124 |
+
shellingham==1.5.4
|
125 |
+
six==1.16.0
|
126 |
+
sniffio==1.3.1
|
127 |
+
stack-data==0.6.3
|
128 |
+
starlette==0.41.3
|
129 |
+
sympy==1.13.1; python_version >= "3.9"
|
130 |
+
tblib==3.0.0
|
131 |
+
threadpoolctl==3.5.0
|
132 |
+
tokenizers==0.20.3
|
133 |
+
tomlkit==0.12.0
|
134 |
+
torch==2.5.1
|
135 |
+
tqdm==4.67.0
|
136 |
+
traitlets==5.14.3
|
137 |
+
transformers==4.46.3
|
138 |
+
triton==3.1.0; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13"
|
139 |
+
typer==0.13.1
|
140 |
+
typing-extensions==4.12.2
|
141 |
+
tzdata==2024.2
|
142 |
+
universal-pathlib==0.2.5
|
143 |
+
urllib3==2.2.3
|
144 |
+
uvicorn==0.32.1; sys_platform != "emscripten"
|
145 |
+
wcwidth==0.2.13
|
146 |
+
websockets==12.0
|
147 |
+
xxhash==3.5.0
|
148 |
+
yarl==1.18.0
|
src/distilabel_dataset_generator/_tabbedinterface.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file defines a high-level abstraction to build Gradio apps: TabbedInterface.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from __future__ import annotations
|
6 |
+
|
7 |
+
from collections.abc import Sequence
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
from gradio.blocks import Blocks
|
11 |
+
from gradio.components import HTML
|
12 |
+
from gradio.layouts import Tab, Tabs
|
13 |
+
from gradio.themes import ThemeClass as Theme
|
14 |
+
from gradio_client.documentation import document
|
15 |
+
|
16 |
+
|
17 |
+
@document()
|
18 |
+
class TabbedInterface(Blocks):
|
19 |
+
"""
|
20 |
+
A TabbedInterface is created by providing a list of Interfaces or Blocks, each of which gets
|
21 |
+
rendered in a separate tab. Only the components from the Interface/Blocks will be rendered in the tab.
|
22 |
+
Certain high-level attributes of the Blocks (e.g. custom `css`, `js`, and `head` attributes) will not be loaded.
|
23 |
+
|
24 |
+
Demos: tabbed_interface_lite
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
interface_list: Sequence[Blocks],
|
30 |
+
tab_names: list[str] | None = None,
|
31 |
+
title: str | None = None,
|
32 |
+
theme: Theme | str | None = None,
|
33 |
+
analytics_enabled: bool | None = None,
|
34 |
+
css: str | None = None,
|
35 |
+
js: str | None = None,
|
36 |
+
head: str | None = None,
|
37 |
+
):
|
38 |
+
"""
|
39 |
+
Parameters:
|
40 |
+
interface_list: A list of Interfaces (or Blocks) to be rendered in the tabs.
|
41 |
+
tab_names: A list of tab names. If None, the tab names will be "Tab 0", "Tab 1", etc.
|
42 |
+
title: The tab title to display when this demo is opened in a browser window.
|
43 |
+
theme: A Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the Hugging Face Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
|
44 |
+
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
|
45 |
+
css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
|
46 |
+
js: Custom js as a string or path to a js file. The custom js should be in the form of a single js function. This function will automatically be executed when the page loads. For more flexibility, use the head parameter to insert js inside <script> tags.
|
47 |
+
head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, multiple scripts, stylesheets, etc. to the page.
|
48 |
+
Returns:
|
49 |
+
a Gradio Tabbed Interface for the given interfaces
|
50 |
+
"""
|
51 |
+
super().__init__(
|
52 |
+
title=title or "Gradio",
|
53 |
+
theme=theme,
|
54 |
+
analytics_enabled=analytics_enabled,
|
55 |
+
mode="tabbed_interface",
|
56 |
+
css=css,
|
57 |
+
js=js,
|
58 |
+
head=head,
|
59 |
+
)
|
60 |
+
if tab_names is None:
|
61 |
+
tab_names = [f"Tab {i}" for i in range(len(interface_list))]
|
62 |
+
with self:
|
63 |
+
if title:
|
64 |
+
HTML(value=title)
|
65 |
+
with gr.Row():
|
66 |
+
with gr.Column(scale=1):
|
67 |
+
gr.LoginButton(value="Sign in!", size="sm", scale=2)
|
68 |
+
with gr.Column(scale=3):
|
69 |
+
pass
|
70 |
+
with Tabs():
|
71 |
+
for interface, tab_name in zip(interface_list, tab_names, strict=False):
|
72 |
+
with Tab(label=tab_name):
|
73 |
+
interface.render()
|
src/distilabel_dataset_generator/apps/base.py
CHANGED
@@ -168,8 +168,7 @@ def get_main_ui(
|
|
168 |
|
169 |
def validate_argilla_user_workspace_dataset(
|
170 |
dataset_name: str,
|
171 |
-
|
172 |
-
add_to_existing_dataset: bool,
|
173 |
oauth_token: Union[OAuthToken, None] = None,
|
174 |
progress=gr.Progress(),
|
175 |
) -> str:
|
@@ -193,7 +192,7 @@ def validate_argilla_user_workspace_dataset(
|
|
193 |
dataset = client.datasets(name=dataset_name, workspace=hf_user)
|
194 |
if dataset and not add_to_existing_dataset:
|
195 |
raise gr.Error(f"Dataset {dataset_name} already exists")
|
196 |
-
return
|
197 |
|
198 |
|
199 |
def get_org_dropdown(oauth_token: OAuthToken = None):
|
@@ -302,7 +301,8 @@ def get_iterate_on_sample_dataset_ui(
|
|
302 |
|
303 |
|
304 |
def get_pipeline_code_ui(pipeline_code: str) -> gr.Code:
|
305 |
-
gr.Markdown("##
|
|
|
306 |
gr.Markdown(
|
307 |
"You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
|
308 |
)
|
@@ -400,7 +400,7 @@ def push_pipeline_code_to_hub(
|
|
400 |
oauth_token: Union[OAuthToken, None] = None,
|
401 |
progress=gr.Progress(),
|
402 |
):
|
403 |
-
repo_id =
|
404 |
progress(0.1, desc="Uploading pipeline code")
|
405 |
with io.BytesIO(pipeline_code.encode("utf-8")) as f:
|
406 |
upload_file(
|
@@ -427,7 +427,7 @@ def push_dataset_to_hub(
|
|
427 |
task: str = TEXTCAT_TASK,
|
428 |
) -> pd.DataFrame:
|
429 |
progress(0.1, desc="Setting up dataset")
|
430 |
-
repo_id =
|
431 |
|
432 |
if task == TEXTCAT_TASK:
|
433 |
if num_labels == 1:
|
@@ -459,7 +459,7 @@ def push_dataset_to_hub(
|
|
459 |
return dataframe
|
460 |
|
461 |
|
462 |
-
def
|
463 |
repo_id = (
|
464 |
f"{org_name}/{repo_name}"
|
465 |
if repo_name is not None and org_name is not None
|
@@ -491,7 +491,7 @@ def get_success_message_row() -> gr.Markdown:
|
|
491 |
return success_message
|
492 |
|
493 |
|
494 |
-
def
|
495 |
client = get_argilla_client()
|
496 |
argilla_api_url = client.api_url
|
497 |
return gr.Markdown(
|
@@ -499,7 +499,13 @@ def show_success_message_argilla() -> gr.Markdown:
|
|
499 |
<div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
|
500 |
<h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
|
501 |
<p style="margin-top: 0.5em;">
|
502 |
-
Your dataset is now available
|
|
|
|
|
|
|
|
|
|
|
|
|
503 |
<a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
|
504 |
{argilla_api_url}
|
505 |
</a>
|
@@ -513,23 +519,5 @@ def show_success_message_argilla() -> gr.Markdown:
|
|
513 |
)
|
514 |
|
515 |
|
516 |
-
def show_success_message_hub(org_name, repo_name) -> gr.Markdown:
|
517 |
-
return gr.Markdown(
|
518 |
-
value=f"""
|
519 |
-
<div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
|
520 |
-
<h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
|
521 |
-
<p style="margin-top: 0.5em;">
|
522 |
-
The generated dataset is in the right format for fine-tuning with TRL, AutoTrain or other frameworks.
|
523 |
-
Your dataset is now available at:
|
524 |
-
<a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
|
525 |
-
https://huggingface.co/datasets/{org_name}/{repo_name}
|
526 |
-
</a>
|
527 |
-
</p>
|
528 |
-
</div>
|
529 |
-
""",
|
530 |
-
visible=True,
|
531 |
-
)
|
532 |
-
|
533 |
-
|
534 |
def hide_success_message() -> gr.Markdown:
|
535 |
-
return gr.Markdown(
|
|
|
168 |
|
169 |
def validate_argilla_user_workspace_dataset(
|
170 |
dataset_name: str,
|
171 |
+
add_to_existing_dataset: bool = True,
|
|
|
172 |
oauth_token: Union[OAuthToken, None] = None,
|
173 |
progress=gr.Progress(),
|
174 |
) -> str:
|
|
|
192 |
dataset = client.datasets(name=dataset_name, workspace=hf_user)
|
193 |
if dataset and not add_to_existing_dataset:
|
194 |
raise gr.Error(f"Dataset {dataset_name} already exists")
|
195 |
+
return ""
|
196 |
|
197 |
|
198 |
def get_org_dropdown(oauth_token: OAuthToken = None):
|
|
|
301 |
|
302 |
|
303 |
def get_pipeline_code_ui(pipeline_code: str) -> gr.Code:
|
304 |
+
gr.Markdown("## Customize and run locally with distilabel")
|
305 |
+
gr.HTML("<hr>")
|
306 |
gr.Markdown(
|
307 |
"You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
|
308 |
)
|
|
|
400 |
oauth_token: Union[OAuthToken, None] = None,
|
401 |
progress=gr.Progress(),
|
402 |
):
|
403 |
+
repo_id = validate_push_to_hub(org_name, repo_name)
|
404 |
progress(0.1, desc="Uploading pipeline code")
|
405 |
with io.BytesIO(pipeline_code.encode("utf-8")) as f:
|
406 |
upload_file(
|
|
|
427 |
task: str = TEXTCAT_TASK,
|
428 |
) -> pd.DataFrame:
|
429 |
progress(0.1, desc="Setting up dataset")
|
430 |
+
repo_id = validate_push_to_hub(org_name, repo_name)
|
431 |
|
432 |
if task == TEXTCAT_TASK:
|
433 |
if num_labels == 1:
|
|
|
459 |
return dataframe
|
460 |
|
461 |
|
462 |
+
def validate_push_to_hub(org_name, repo_name):
|
463 |
repo_id = (
|
464 |
f"{org_name}/{repo_name}"
|
465 |
if repo_name is not None and org_name is not None
|
|
|
491 |
return success_message
|
492 |
|
493 |
|
494 |
+
def show_success_message_hub(org_name, repo_name) -> gr.Markdown:
|
495 |
client = get_argilla_client()
|
496 |
argilla_api_url = client.api_url
|
497 |
return gr.Markdown(
|
|
|
499 |
<div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
|
500 |
<h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
|
501 |
<p style="margin-top: 0.5em;">
|
502 |
+
Your dataset is now available on the Hugging Face Hub:
|
503 |
+
<a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
|
504 |
+
https://huggingface.co/datasets/{org_name}/{repo_name}
|
505 |
+
</a>
|
506 |
+
</p>
|
507 |
+
<p style="margin-top: 0.5em;">
|
508 |
+
Your dataset is now available within Argilla:
|
509 |
<a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
|
510 |
{argilla_api_url}
|
511 |
</a>
|
|
|
519 |
)
|
520 |
|
521 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
522 |
def hide_success_message() -> gr.Markdown:
|
523 |
+
return gr.Markdown(value="")
|
src/distilabel_dataset_generator/apps/eval.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import pandas as pd
|
5 |
+
from datasets import load_dataset
|
6 |
+
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
7 |
+
|
8 |
+
from src.distilabel_dataset_generator.utils import get_org_dropdown
|
9 |
+
|
10 |
+
|
11 |
+
def get_iframe(hub_repo_id) -> str:
|
12 |
+
if not hub_repo_id:
|
13 |
+
raise gr.Error("Hub repo id is required")
|
14 |
+
url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
|
15 |
+
iframe = f"""
|
16 |
+
<iframe
|
17 |
+
src="{url}"
|
18 |
+
frameborder="0"
|
19 |
+
width="100%"
|
20 |
+
height="600px"
|
21 |
+
></iframe>
|
22 |
+
"""
|
23 |
+
return iframe
|
24 |
+
|
25 |
+
|
26 |
+
def get_valid_columns(df: pd.DataFrame):
|
27 |
+
valid_columns = []
|
28 |
+
for col in df.columns:
|
29 |
+
sample_val = df[col].iloc[0]
|
30 |
+
if isinstance(sample_val, str) or (
|
31 |
+
isinstance(sample_val, list)
|
32 |
+
and all(isinstance(item, dict) for item in sample_val)
|
33 |
+
):
|
34 |
+
valid_columns.append(col)
|
35 |
+
return valid_columns
|
36 |
+
|
37 |
+
|
38 |
+
def load_dataset_from_hub(hub_repo_id: str, n_rows: int = 10):
|
39 |
+
gr.Info(message="Loading dataset ...")
|
40 |
+
if not hub_repo_id:
|
41 |
+
raise gr.Error("Hub repo id is required")
|
42 |
+
ds_dict = load_dataset(hub_repo_id)
|
43 |
+
splits = list(ds_dict.keys())
|
44 |
+
ds = ds_dict[splits[0]]
|
45 |
+
if n_rows:
|
46 |
+
ds = ds.select(range(n_rows))
|
47 |
+
df = ds.to_pandas()
|
48 |
+
# Get columns that contain either strings or lists of dictionaries
|
49 |
+
valid_columns = get_valid_columns(df)
|
50 |
+
return (
|
51 |
+
df,
|
52 |
+
gr.Dropdown(choices=valid_columns, label="Instruction Column"),
|
53 |
+
gr.Dropdown(choices=valid_columns, label="Instruction Column"),
|
54 |
+
gr.Dropdown(choices=valid_columns, label="Response Column"),
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
def define_evaluation_aspects(task_type: str):
|
59 |
+
if task_type == "instruction":
|
60 |
+
return gr.Dropdown(
|
61 |
+
value=["overall-rating"],
|
62 |
+
choices=["complexity", "quality"],
|
63 |
+
label="Evaluation Aspects",
|
64 |
+
multiselect=True,
|
65 |
+
interactive=True,
|
66 |
+
)
|
67 |
+
elif task_type == "instruction-response":
|
68 |
+
return gr.Dropdown(
|
69 |
+
value=["overall-rating"],
|
70 |
+
choices=["helpfulness", "truthfulness", "overall-rating", "honesty"],
|
71 |
+
label="Evaluation Aspects",
|
72 |
+
multiselect=True,
|
73 |
+
interactive=True,
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
return gr.Dropdown(interactive=False)
|
77 |
+
|
78 |
+
|
79 |
+
def evaluate_instruction(df: pd.DataFrame, aspects: list[str], instruction_column: str):
|
80 |
+
pass
|
81 |
+
|
82 |
+
|
83 |
+
def evaluate_instruction_response(
|
84 |
+
df: pd.DataFrame, aspects: list[str], instruction_column: str, response_column: str
|
85 |
+
):
|
86 |
+
pass
|
87 |
+
|
88 |
+
|
89 |
+
def evaluate_custom(
|
90 |
+
df: pd.DataFrame, aspects: list[str], prompt_template: str, structured_output: dict
|
91 |
+
):
|
92 |
+
pass
|
93 |
+
|
94 |
+
|
95 |
+
def _apply_to_dataset(
|
96 |
+
df: pd.DataFrame,
|
97 |
+
eval_type: str,
|
98 |
+
aspects_instruction: list[str],
|
99 |
+
instruction_column: str,
|
100 |
+
aspects_instruction_response: list[str],
|
101 |
+
instruction_column_response: str,
|
102 |
+
response_column_response: str,
|
103 |
+
aspects_custom: list[str],
|
104 |
+
prompt_template: str,
|
105 |
+
structured_output: dict,
|
106 |
+
):
|
107 |
+
if eval_type == "instruction":
|
108 |
+
df = evaluate_instruction(df, aspects_instruction, instruction_column)
|
109 |
+
elif eval_type == "instruction-response":
|
110 |
+
df = evaluate_instruction_response(
|
111 |
+
df,
|
112 |
+
aspects_instruction_response,
|
113 |
+
instruction_column_response,
|
114 |
+
response_column_response,
|
115 |
+
)
|
116 |
+
elif eval_type == "custom":
|
117 |
+
df = evaluate_custom(df, aspects_custom, prompt_template, structured_output)
|
118 |
+
return df
|
119 |
+
|
120 |
+
|
121 |
+
def apply_to_sample_dataset(
|
122 |
+
repo_id: str,
|
123 |
+
eval_type: str,
|
124 |
+
aspects_instruction: list[str],
|
125 |
+
aspects_instruction_response: list[str],
|
126 |
+
aspects_custom: list[str],
|
127 |
+
instruction_instruction: str,
|
128 |
+
instruction_instruction_response: str,
|
129 |
+
response_instruction_response: str,
|
130 |
+
prompt_template: str,
|
131 |
+
structured_output: dict,
|
132 |
+
):
|
133 |
+
df, _, _, _ = load_dataset_from_hub(repo_id, n_rows=10)
|
134 |
+
df = _apply_to_dataset(
|
135 |
+
df,
|
136 |
+
eval_type,
|
137 |
+
aspects_instruction,
|
138 |
+
instruction_instruction,
|
139 |
+
aspects_instruction_response,
|
140 |
+
instruction_instruction_response,
|
141 |
+
response_instruction_response,
|
142 |
+
aspects_custom,
|
143 |
+
prompt_template,
|
144 |
+
structured_output,
|
145 |
+
)
|
146 |
+
return df
|
147 |
+
|
148 |
+
|
149 |
+
def push_to_hub(
    org_name: str,
    repo_name: str,
    private: bool,
    n_rows: int,
    original_repo_id: str,
    eval_type: str,
    aspects_instruction: list[str],
    aspects_instruction_response: list[str],
    aspects_custom: list[str],
    instruction_instruction: str,
    instruction_instruction_response: str,
    response_instruction_response: str,
    prompt_template: str,
    structured_output: dict,
):
    """Evaluate ``n_rows`` rows of ``original_repo_id`` and print the result.

    NOTE: despite its name this function does not upload anything yet. It
    loads the data, applies the configured evaluation, builds the target
    repository id and prints the evaluated data. ``private`` is accepted but
    not used. Nothing is returned.
    """
    # The loader returns four values; only the first one (the data) is used here.
    source, _, _, _ = load_dataset_from_hub(original_repo_id, n_rows=n_rows)
    evaluated = _apply_to_dataset(
        df=source,
        eval_type=eval_type,
        aspects_instruction=aspects_instruction,
        instruction_column=instruction_instruction,
        aspects_instruction_response=aspects_instruction_response,
        instruction_column_response=instruction_instruction_response,
        response_column_response=response_instruction_response,
        aspects_custom=aspects_custom,
        prompt_template=prompt_template,
        structured_output=structured_output,
    )
    # Target repository id; computed but not consumed by anything yet.
    new_repo_id = f"{org_name}/{repo_name}"
    print(evaluated)
|
180 |
+
|
181 |
+
|
182 |
+
# Evaluation app: select a Hub dataset, configure an evaluation, run it.
with gr.Blocks() as app:
    # --- Step 1: pick the input dataset --------------------------------
    gr.Markdown("## Select your input dataset")
    gr.HTML("<hr>")
    with gr.Row():
        with gr.Column(scale=1):
            # NOTE(review): the keyword is spelled `sumbit_on_select` exactly
            # as in the original code; confirm against the component's
            # signature before "fixing" the spelling.
            search_in = HuggingfaceHubSearch(
                label="Search",
                placeholder="Search for a Dataset",
                search_type="dataset",
                sumbit_on_select=True,
            )
            load_btn = gr.Button("Load Dataset")
        with gr.Column(scale=3):
            search_out = gr.HTML(label="Dataset Preview")

    # --- Step 2: configure the evaluation task -------------------------
    gr.Markdown("## Configure your task")
    gr.HTML("<hr>")
    with gr.Row():
        with gr.Column(scale=1):
            # Hidden state holding the active evaluation type. It is only
            # updated by the tab `select` handlers below, so it is given an
            # explicit initial value matching the first tab; otherwise it can
            # disagree with the tab shown before any tab has been clicked.
            eval_type = gr.Dropdown(
                label="Evaluation Type",
                choices=["instruction", "instruction-response", "custom"],
                value="instruction",
                visible=False,
            )
            with gr.Tab("instruction") as tab_instruction:
                aspects_instruction = define_evaluation_aspects("instruction")
                instruction_instruction = gr.Dropdown(
                    label="Instruction Column", interactive=True
                )
                tab_instruction.select(
                    lambda: "instruction",
                    inputs=[],
                    outputs=[eval_type],
                )
            with gr.Tab("instruction-response") as tab_instruction_response:
                aspects_instruction_response = define_evaluation_aspects(
                    "instruction-response"
                )
                instruction_instruction_response = gr.Dropdown(
                    label="Instruction Column", interactive=True
                )
                response_instruction_response = gr.Dropdown(
                    label="Response Column", interactive=True
                )
                tab_instruction_response.select(
                    lambda: "instruction-response",
                    inputs=[],
                    outputs=[eval_type],
                )
            with gr.Tab("custom") as tab_custom:
                aspects_custom = define_evaluation_aspects("custom")
                prompt_template = gr.Code(
                    label="Prompt Template",
                    value="{{column_1}} based on {{column_2}}",
                    language="markdown",
                    interactive=True,
                )
                structured_output = gr.Code(
                    label="Structured Output",
                    value=json.dumps({"eval_aspect": "str"}),
                    language="json",
                    interactive=True,
                )
                tab_custom.select(
                    lambda: "custom",
                    inputs=[],
                    outputs=[eval_type],
                )
            btn_apply_to_sample_dataset = gr.Button("Refresh dataset")
        with gr.Column(scale=3):
            dataframe = gr.Dataframe()

    # --- Step 3: run on the full dataset --------------------------------
    gr.Markdown("## Generate your dataset")
    gr.HTML("<hr>")
    with gr.Row():
        with gr.Column(scale=1):
            org_name = get_org_dropdown()
            repo_name = gr.Textbox(
                label="Repo name",
                placeholder="dataset_name",
                value="my-distiset",
                interactive=True,
            )
            n_rows = gr.Number(
                label="Number of rows",
                value=10,
                interactive=True,
                scale=1,
            )
            private = gr.Checkbox(
                label="Private dataset",
                value=False,
                interactive=True,
                scale=1,
            )
            btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
        with gr.Column(scale=3):
            success_message = gr.Markdown(visible=False)

    # --- Event wiring -----------------------------------------------------
    # Submitting a search result renders the dataset preview.
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out)
    # Loading fills the sample table and the three column dropdowns.
    load_btn.click(
        load_dataset_from_hub,
        inputs=[search_in],
        outputs=[
            dataframe,
            instruction_instruction,
            instruction_instruction_response,
            response_instruction_response,
        ],
    )
    # The input order below must match the parameter order of the handlers.
    btn_apply_to_sample_dataset.click(
        apply_to_sample_dataset,
        inputs=[
            search_in,
            eval_type,
            aspects_instruction,
            aspects_instruction_response,
            aspects_custom,
            instruction_instruction,
            instruction_instruction_response,
            response_instruction_response,
            prompt_template,
            structured_output,
        ],
        outputs=dataframe,
    )
    btn_push_to_hub.click(
        push_to_hub,
        inputs=[
            org_name,
            repo_name,
            private,
            n_rows,
            search_in,
            eval_type,
            aspects_instruction,
            aspects_instruction_response,
            aspects_custom,
            instruction_instruction,
            instruction_instruction_response,
            response_instruction_response,
            prompt_template,
            structured_output,
        ],
        outputs=success_message,
    )
    # Rebuild the organisation dropdown each time the page loads.
    app.load(fn=get_org_dropdown, outputs=[org_name])
|
src/distilabel_dataset_generator/apps/sft.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import ast
|
|
|
2 |
from typing import Dict, List, Union
|
3 |
|
4 |
import argilla as rg
|
@@ -10,16 +11,11 @@ from huggingface_hub import HfApi
|
|
10 |
|
11 |
from src.distilabel_dataset_generator.apps.base import (
|
12 |
get_argilla_client,
|
13 |
-
get_main_ui,
|
14 |
get_pipeline_code_ui,
|
15 |
hide_success_message,
|
16 |
-
push_pipeline_code_to_hub,
|
17 |
-
show_success_message_argilla,
|
18 |
show_success_message_hub,
|
19 |
validate_argilla_user_workspace_dataset,
|
20 |
-
|
21 |
-
from src.distilabel_dataset_generator.apps.base import (
|
22 |
-
push_dataset_to_hub as push_to_hub_base,
|
23 |
)
|
24 |
from src.distilabel_dataset_generator.pipelines.base import (
|
25 |
DEFAULT_BATCH_SIZE,
|
@@ -30,16 +26,15 @@ from src.distilabel_dataset_generator.pipelines.embeddings import (
|
|
30 |
)
|
31 |
from src.distilabel_dataset_generator.pipelines.sft import (
|
32 |
DEFAULT_DATASET_DESCRIPTIONS,
|
33 |
-
DEFAULT_DATASETS,
|
34 |
-
DEFAULT_SYSTEM_PROMPTS,
|
35 |
PROMPT_CREATION_PROMPT,
|
36 |
generate_pipeline_code,
|
37 |
get_magpie_generator,
|
38 |
get_prompt_generator,
|
39 |
get_response_generator,
|
40 |
)
|
41 |
-
|
42 |
-
|
|
|
43 |
|
44 |
|
45 |
def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
|
@@ -57,33 +52,176 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
|
|
57 |
return dataframe
|
58 |
|
59 |
|
60 |
-
def
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
progress=gr.Progress(),
|
67 |
-
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
original_dataframe = dataframe.copy(deep=True)
|
69 |
dataframe = convert_dataframe_messages(dataframe)
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
76 |
return original_dataframe
|
77 |
|
78 |
|
79 |
def push_dataset_to_argilla(
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
82 |
oauth_token: Union[gr.OAuthToken, None] = None,
|
83 |
progress=gr.Progress(),
|
84 |
) -> pd.DataFrame:
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
87 |
try:
|
88 |
progress(0.1, desc="Setting up user and workspace")
|
89 |
client = get_argilla_client()
|
@@ -185,10 +323,10 @@ def push_dataset_to_argilla(
|
|
185 |
dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
|
186 |
|
187 |
progress(0.5, desc="Creating dataset")
|
188 |
-
rg_dataset = client.datasets(name=
|
189 |
if rg_dataset is None:
|
190 |
rg_dataset = rg.Dataset(
|
191 |
-
name=
|
192 |
workspace=hf_user,
|
193 |
settings=settings,
|
194 |
client=client,
|
@@ -200,309 +338,123 @@ def push_dataset_to_argilla(
|
|
200 |
progress(1.0, desc="Dataset pushed to Argilla")
|
201 |
except Exception as e:
|
202 |
raise gr.Error(f"Error pushing dataset to Argilla: {e}")
|
203 |
-
return
|
204 |
-
|
205 |
-
|
206 |
-
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
207 |
-
progress(0.0, desc="Generating system prompt")
|
208 |
-
if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
|
209 |
-
index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
|
210 |
-
if index < len(DEFAULT_SYSTEM_PROMPTS):
|
211 |
-
return DEFAULT_SYSTEM_PROMPTS[index]
|
212 |
|
213 |
-
progress(0.3, desc="Initializing text generation")
|
214 |
-
generate_description = get_prompt_generator()
|
215 |
-
progress(0.7, desc="Generating system prompt")
|
216 |
-
result = next(
|
217 |
-
generate_description.process(
|
218 |
-
[
|
219 |
-
{
|
220 |
-
"system_prompt": PROMPT_CREATION_PROMPT,
|
221 |
-
"instruction": dataset_description,
|
222 |
-
}
|
223 |
-
]
|
224 |
-
)
|
225 |
-
)[0]["generation"]
|
226 |
-
progress(1.0, desc="System prompt generated")
|
227 |
-
return result
|
228 |
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
progress(0.0, desc="(1/2) Generating instructions")
|
238 |
-
magpie_generator = get_magpie_generator(
|
239 |
-
num_turns, num_rows, system_prompt, is_sample
|
240 |
-
)
|
241 |
-
response_generator = get_response_generator(num_turns, system_prompt, is_sample)
|
242 |
-
total_steps: int = num_rows * 2
|
243 |
-
batch_size = DEFAULT_BATCH_SIZE
|
244 |
-
|
245 |
-
# create instructions
|
246 |
-
n_processed = 0
|
247 |
-
magpie_results = []
|
248 |
-
while n_processed < num_rows:
|
249 |
-
progress(
|
250 |
-
0.5 * n_processed / num_rows,
|
251 |
-
total=total_steps,
|
252 |
-
desc="(1/2) Generating instructions",
|
253 |
-
)
|
254 |
-
remaining_rows = num_rows - n_processed
|
255 |
-
batch_size = min(batch_size, remaining_rows)
|
256 |
-
inputs = [{"system_prompt": system_prompt} for _ in range(batch_size)]
|
257 |
-
batch = list(magpie_generator.process(inputs=inputs))
|
258 |
-
magpie_results.extend(batch[0])
|
259 |
-
n_processed += batch_size
|
260 |
-
progress(0.5, desc="(1/2) Generating instructions")
|
261 |
-
|
262 |
-
# generate responses
|
263 |
-
n_processed = 0
|
264 |
-
response_results = []
|
265 |
-
if num_turns == 1:
|
266 |
-
while n_processed < num_rows:
|
267 |
-
progress(
|
268 |
-
0.5 + 0.5 * n_processed / num_rows,
|
269 |
-
total=total_steps,
|
270 |
-
desc="(2/2) Generating responses",
|
271 |
-
)
|
272 |
-
batch = magpie_results[n_processed : n_processed + batch_size]
|
273 |
-
responses = list(response_generator.process(inputs=batch))
|
274 |
-
response_results.extend(responses[0])
|
275 |
-
n_processed += batch_size
|
276 |
-
for result in response_results:
|
277 |
-
result["prompt"] = result["instruction"]
|
278 |
-
result["completion"] = result["generation"]
|
279 |
-
result["system_prompt"] = system_prompt
|
280 |
-
else:
|
281 |
-
for result in magpie_results:
|
282 |
-
result["conversation"].insert(
|
283 |
-
0, {"role": "system", "content": system_prompt}
|
284 |
)
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
desc="(2/2) Generating responses",
|
291 |
)
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
for result in response_results:
|
297 |
-
result["messages"].append(
|
298 |
-
{"role": "assistant", "content": result["generation"]}
|
299 |
)
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
)
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
for result in response_results:
|
309 |
-
record = {}
|
310 |
-
for relevant_keys in [
|
311 |
-
"messages",
|
312 |
-
"prompt",
|
313 |
-
"completion",
|
314 |
-
"model_name",
|
315 |
-
"system_prompt",
|
316 |
-
]:
|
317 |
-
if relevant_keys in result:
|
318 |
-
record[relevant_keys] = result[relevant_keys]
|
319 |
-
distiset_results.append(record)
|
320 |
-
|
321 |
-
distiset = Distiset(
|
322 |
-
{
|
323 |
-
"default": Dataset.from_list(distiset_results),
|
324 |
-
}
|
325 |
-
)
|
326 |
-
|
327 |
-
# If not pushing to hub generate the dataset directly
|
328 |
-
distiset = distiset["default"]
|
329 |
-
if num_turns == 1:
|
330 |
-
outputs = distiset.to_pandas()[["system_prompt", "prompt", "completion"]]
|
331 |
-
else:
|
332 |
-
outputs = distiset.to_pandas()[["messages"]]
|
333 |
-
dataframe = pd.DataFrame(outputs)
|
334 |
-
progress(1.0, desc="Dataset generation completed")
|
335 |
-
return dataframe
|
336 |
-
|
337 |
-
|
338 |
-
(
|
339 |
-
app,
|
340 |
-
main_ui,
|
341 |
-
custom_input_ui,
|
342 |
-
dataset_description,
|
343 |
-
examples,
|
344 |
-
btn_generate_system_prompt,
|
345 |
-
system_prompt,
|
346 |
-
sample_dataset,
|
347 |
-
btn_generate_sample_dataset,
|
348 |
-
dataset_name,
|
349 |
-
add_to_existing_dataset,
|
350 |
-
btn_generate_full_dataset_argilla,
|
351 |
-
btn_generate_and_push_to_argilla,
|
352 |
-
btn_push_to_argilla,
|
353 |
-
org_name,
|
354 |
-
repo_name,
|
355 |
-
private,
|
356 |
-
btn_generate_full_dataset,
|
357 |
-
btn_generate_and_push_to_hub,
|
358 |
-
btn_push_to_hub,
|
359 |
-
final_dataset,
|
360 |
-
success_message,
|
361 |
-
) = get_main_ui(
|
362 |
-
default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
|
363 |
-
default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
|
364 |
-
default_datasets=DEFAULT_DATASETS,
|
365 |
-
fn_generate_system_prompt=generate_system_prompt,
|
366 |
-
fn_generate_dataset=generate_dataset,
|
367 |
-
task=TASK,
|
368 |
-
)
|
369 |
-
|
370 |
-
with app:
|
371 |
-
with main_ui:
|
372 |
-
with custom_input_ui:
|
373 |
num_turns = gr.Number(
|
374 |
value=1,
|
375 |
label="Number of turns in the conversation",
|
376 |
minimum=1,
|
377 |
maximum=4,
|
378 |
step=1,
|
|
|
379 |
info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
|
380 |
)
|
381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
value=10,
|
383 |
-
|
384 |
-
|
385 |
-
maximum=500,
|
386 |
-
info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
|
387 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
|
393 |
-
# define app triggers
|
394 |
gr.on(
|
395 |
-
triggers=[
|
396 |
-
|
397 |
-
|
398 |
-
],
|
399 |
-
|
400 |
-
outputs=[success_message],
|
401 |
).then(
|
402 |
-
fn=
|
403 |
-
inputs=[system_prompt
|
404 |
-
outputs=[
|
405 |
show_progress=True,
|
406 |
)
|
407 |
|
408 |
-
|
409 |
fn=validate_argilla_user_workspace_dataset,
|
410 |
-
inputs=[
|
411 |
-
outputs=[final_dataset],
|
412 |
-
show_progress=True,
|
413 |
-
).success(
|
414 |
-
fn=hide_success_message,
|
415 |
outputs=[success_message],
|
416 |
-
).success(
|
417 |
-
fn=generate_dataset,
|
418 |
-
inputs=[system_prompt, num_turns, num_rows],
|
419 |
-
outputs=[final_dataset],
|
420 |
-
show_progress=True,
|
421 |
-
).success(
|
422 |
-
fn=push_dataset_to_argilla,
|
423 |
-
inputs=[final_dataset, dataset_name],
|
424 |
-
outputs=[final_dataset],
|
425 |
-
show_progress=True,
|
426 |
-
).success(
|
427 |
-
fn=show_success_message_argilla,
|
428 |
-
inputs=[],
|
429 |
-
outputs=[success_message],
|
430 |
-
)
|
431 |
-
|
432 |
-
btn_generate_and_push_to_hub.click(
|
433 |
-
fn=hide_success_message,
|
434 |
-
outputs=[success_message],
|
435 |
-
).then(
|
436 |
-
fn=generate_dataset,
|
437 |
-
inputs=[system_prompt, num_turns, num_rows],
|
438 |
-
outputs=[final_dataset],
|
439 |
-
show_progress=True,
|
440 |
-
).then(
|
441 |
-
fn=push_dataset_to_hub,
|
442 |
-
inputs=[final_dataset, private, org_name, repo_name],
|
443 |
-
outputs=[final_dataset],
|
444 |
show_progress=True,
|
445 |
).then(
|
446 |
-
fn=
|
447 |
-
inputs=[pipeline_code, org_name, repo_name],
|
448 |
-
outputs=[],
|
449 |
-
show_progress=True,
|
450 |
-
).success(
|
451 |
-
fn=show_success_message_hub,
|
452 |
inputs=[org_name, repo_name],
|
453 |
outputs=[success_message],
|
454 |
-
)
|
455 |
-
|
456 |
-
btn_push_to_hub.click(
|
457 |
-
fn=hide_success_message,
|
458 |
-
outputs=[success_message],
|
459 |
-
).then(
|
460 |
-
fn=push_dataset_to_hub,
|
461 |
-
inputs=[final_dataset, private, org_name, repo_name],
|
462 |
-
outputs=[final_dataset],
|
463 |
-
show_progress=True,
|
464 |
-
).then(
|
465 |
-
fn=push_pipeline_code_to_hub,
|
466 |
-
inputs=[pipeline_code, org_name, repo_name],
|
467 |
-
outputs=[],
|
468 |
show_progress=True,
|
469 |
).success(
|
470 |
-
fn=show_success_message_hub,
|
471 |
-
inputs=[org_name, repo_name],
|
472 |
-
outputs=[success_message],
|
473 |
-
)
|
474 |
-
|
475 |
-
btn_push_to_argilla.click(
|
476 |
fn=hide_success_message,
|
477 |
outputs=[success_message],
|
478 |
-
).success(
|
479 |
-
fn=validate_argilla_user_workspace_dataset,
|
480 |
-
inputs=[dataset_name, final_dataset, add_to_existing_dataset],
|
481 |
-
outputs=[final_dataset],
|
482 |
show_progress=True,
|
483 |
).success(
|
484 |
fn=push_dataset_to_argilla,
|
485 |
-
inputs=[
|
486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
487 |
show_progress=True,
|
488 |
).success(
|
489 |
-
fn=
|
490 |
-
inputs=[],
|
491 |
outputs=[success_message],
|
492 |
)
|
493 |
-
|
494 |
-
system_prompt.change(
|
495 |
-
fn=generate_pipeline_code,
|
496 |
-
inputs=[system_prompt, num_turns, num_rows],
|
497 |
-
outputs=[pipeline_code],
|
498 |
-
)
|
499 |
-
num_turns.change(
|
500 |
-
fn=generate_pipeline_code,
|
501 |
-
inputs=[system_prompt, num_turns, num_rows],
|
502 |
-
outputs=[pipeline_code],
|
503 |
-
)
|
504 |
-
num_rows.change(
|
505 |
-
fn=generate_pipeline_code,
|
506 |
-
inputs=[system_prompt, num_turns, num_rows],
|
507 |
-
outputs=[pipeline_code],
|
508 |
-
)
|
|
|
1 |
import ast
|
2 |
+
import uuid
|
3 |
from typing import Dict, List, Union
|
4 |
|
5 |
import argilla as rg
|
|
|
11 |
|
12 |
from src.distilabel_dataset_generator.apps.base import (
|
13 |
get_argilla_client,
|
|
|
14 |
get_pipeline_code_ui,
|
15 |
hide_success_message,
|
|
|
|
|
16 |
show_success_message_hub,
|
17 |
validate_argilla_user_workspace_dataset,
|
18 |
+
validate_push_to_hub,
|
|
|
|
|
19 |
)
|
20 |
from src.distilabel_dataset_generator.pipelines.base import (
|
21 |
DEFAULT_BATCH_SIZE,
|
|
|
26 |
)
|
27 |
from src.distilabel_dataset_generator.pipelines.sft import (
|
28 |
DEFAULT_DATASET_DESCRIPTIONS,
|
|
|
|
|
29 |
PROMPT_CREATION_PROMPT,
|
30 |
generate_pipeline_code,
|
31 |
get_magpie_generator,
|
32 |
get_prompt_generator,
|
33 |
get_response_generator,
|
34 |
)
|
35 |
+
from src.distilabel_dataset_generator.utils import (
|
36 |
+
get_org_dropdown,
|
37 |
+
)
|
38 |
|
39 |
|
40 |
def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
|
|
|
52 |
return dataframe
|
53 |
|
54 |
|
55 |
+
def generate_system_prompt(dataset_description, progress=gr.Progress()):
    """Turn a free-text dataset description into a system prompt.

    Args:
        dataset_description: Description used as the instruction for the
            prompt generator.
        progress: Gradio progress tracker, updated as the steps complete.

    Returns:
        A tuple ``(system_prompt, empty_dataframe)``: the generated text and a
        new, empty ``pd.DataFrame``.
    """
    progress(0.0, desc="Generating system prompt")

    progress(0.3, desc="Initializing text generation")
    prompt_generator = get_prompt_generator()

    progress(0.7, desc="Generating system prompt")
    request = {
        "system_prompt": PROMPT_CREATION_PROMPT,
        "instruction": dataset_description,
    }
    # `process` yields batches; take the first batch and its first record.
    first_batch = next(prompt_generator.process([request]))
    generated_prompt = first_batch[0]["generation"]

    progress(1.0, desc="System prompt generated")
    return generated_prompt, pd.DataFrame()
|
73 |
+
|
74 |
+
|
75 |
+
def generate_sample_dataset(system_prompt, progress=gr.Progress()):
    """Generate a small preview dataset for ``system_prompt``.

    Delegates to ``generate_dataset`` with fixed settings: a single turn,
    10 rows, and ``is_sample=True``.

    Returns:
        The DataFrame returned by ``generate_dataset``.
    """
    sample_settings = {"num_turns": 1, "num_rows": 10, "is_sample": True}
    return generate_dataset(
        system_prompt=system_prompt,
        progress=progress,
        **sample_settings,
    )
|
84 |
+
|
85 |
+
|
86 |
+
def generate_dataset(
    system_prompt: str,
    num_turns: int = 1,
    num_rows: int = 10,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Generate an instruction/response dataset in two batched phases.

    Phase 1 asks the magpie generator for instructions (or conversations when
    ``num_turns`` is greater than 1); phase 2 asks the response generator for
    the replies.

    Args:
        system_prompt: System prompt passed to both generators and stored
            with every record.
        num_turns: ``1`` for single-turn records, anything else for
            multi-turn conversations.
        num_rows: Number of rows to request.
        is_sample: Forwarded unchanged to both generator factories.
        progress: Gradio progress tracker.

    Returns:
        A DataFrame with the columns ``prompt``, ``completion`` and
        ``system_prompt`` when ``num_turns == 1``, otherwise with a single
        ``messages`` column.
    """
    progress(0.0, desc="(1/2) Generating instructions")
    magpie_generator = get_magpie_generator(
        num_turns, num_rows, system_prompt, is_sample
    )
    response_generator = get_response_generator(num_turns, system_prompt, is_sample)
    total_steps: int = num_rows * 2
    batch_size = DEFAULT_BATCH_SIZE

    # create instructions
    n_processed = 0
    magpie_results = []
    while n_processed < num_rows:
        progress(
            0.5 * n_processed / num_rows,
            total=total_steps,
            desc="(1/2) Generating instructions",
        )
        # Size of this batch only. `batch_size` itself is left untouched so a
        # short final batch does not shrink the batches of the response phase.
        current_batch_size = min(batch_size, num_rows - n_processed)
        inputs = [{"system_prompt": system_prompt} for _ in range(current_batch_size)]
        batch = list(magpie_generator.process(inputs=inputs))
        magpie_results.extend(batch[0])
        n_processed += current_batch_size
    progress(0.5, desc="(1/2) Generating instructions")

    # Multi-turn: put the system prompt first and expose the conversation
    # under "messages" before the records reach the response generator.
    if num_turns != 1:
        for result in magpie_results:
            result["conversation"].insert(
                0, {"role": "system", "content": system_prompt}
            )
            result["messages"] = result["conversation"]

    # generate responses
    n_processed = 0
    response_results = []
    while n_processed < num_rows:
        progress(
            0.5 + 0.5 * n_processed / num_rows,
            total=total_steps,
            desc="(2/2) Generating responses",
        )
        # Slicing clamps at the end of the list, so the last batch may be short.
        batch = magpie_results[n_processed : n_processed + batch_size]
        responses = list(response_generator.process(inputs=batch))
        response_results.extend(responses[0])
        n_processed += batch_size

    if num_turns == 1:
        for result in response_results:
            result["prompt"] = result["instruction"]
            result["completion"] = result["generation"]
            result["system_prompt"] = system_prompt
    else:
        for result in response_results:
            result["messages"].append(
                {"role": "assistant", "content": result["generation"]}
            )
    progress(
        1,
        total=total_steps,
        desc="(2/2) Creating dataset",
    )

    # create distiset, keeping only the keys that belong in the final dataset
    relevant_keys = ("messages", "prompt", "completion", "model_name", "system_prompt")
    distiset_results = [
        {key: result[key] for key in relevant_keys if key in result}
        for result in response_results
    ]

    distiset = Distiset(
        {
            "default": Dataset.from_list(distiset_results),
        }
    )

    # If not pushing to hub generate the dataset directly
    distiset = distiset["default"]
    if num_turns == 1:
        outputs = distiset.to_pandas()[["prompt", "completion", "system_prompt"]]
    else:
        outputs = distiset.to_pandas()[["messages"]]
    dataframe = pd.DataFrame(outputs)
    progress(1.0, desc="Dataset generation completed")
    return dataframe
|
192 |
+
|
193 |
+
|
194 |
+
def push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private):
    """Push ``dataframe`` to the Hugging Face Hub as the ``default`` split.

    Args:
        dataframe: Generated data to upload.
        org_name: Passed with ``repo_name`` to ``validate_push_to_hub``,
            which returns the repository id used for the upload.
        repo_name: See ``org_name``.
        oauth_token: Object whose ``.token`` is used for the upload; ``None``
            is rejected.
        private: Whether the repository is created as private.

    Returns:
        A deep copy of ``dataframe`` taken before the message conversion.

    Raises:
        gr.Error: If ``oauth_token`` is ``None``.
    """
    repo_id = validate_push_to_hub(org_name, repo_name)
    # Callers declare the token as optional; fail with a readable message
    # instead of an AttributeError on `oauth_token.token` further down.
    if oauth_token is None:
        raise gr.Error(
            "A Hugging Face OAuth token is required to push the dataset to the Hub. "
            "Please log in and try again."
        )
    original_dataframe = dataframe.copy(deep=True)
    dataframe = convert_dataframe_messages(dataframe)
    distiset = Distiset({"default": Dataset.from_pandas(dataframe)})
    distiset.push_to_hub(
        repo_id=repo_id,
        private=private,
        include_script=False,
        token=oauth_token.token,
        create_pr=False,
    )
    return original_dataframe
|
207 |
|
208 |
|
209 |
def push_dataset_to_argilla(
|
210 |
+
org_name: str,
|
211 |
+
repo_name: str,
|
212 |
+
system_prompt: str,
|
213 |
+
num_turns: int = 1,
|
214 |
+
n_rows: int = 10,
|
215 |
+
private: bool = False,
|
216 |
oauth_token: Union[gr.OAuthToken, None] = None,
|
217 |
progress=gr.Progress(),
|
218 |
) -> pd.DataFrame:
|
219 |
+
dataframe = generate_dataset(
|
220 |
+
system_prompt=system_prompt,
|
221 |
+
num_turns=num_turns,
|
222 |
+
num_rows=n_rows,
|
223 |
+
)
|
224 |
+
push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
|
225 |
try:
|
226 |
progress(0.1, desc="Setting up user and workspace")
|
227 |
client = get_argilla_client()
|
|
|
323 |
dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
|
324 |
|
325 |
progress(0.5, desc="Creating dataset")
|
326 |
+
rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
|
327 |
if rg_dataset is None:
|
328 |
rg_dataset = rg.Dataset(
|
329 |
+
name=repo_name,
|
330 |
workspace=hf_user,
|
331 |
settings=settings,
|
332 |
client=client,
|
|
|
338 |
progress(1.0, desc="Dataset pushed to Argilla")
|
339 |
except Exception as e:
|
340 |
raise gr.Error(f"Error pushing dataset to Argilla: {e}")
|
341 |
+
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
|
344 |
+
# Supervised fine-tuning app: describe a dataset, preview a sample, push it.
with gr.Blocks() as app:
    # --- Step 1: describe the dataset ------------------------------------
    gr.Markdown("## Describe the dataset you want")
    gr.HTML("<hr>")
    with gr.Row():
        with gr.Column(scale=1):
            dataset_description = gr.Textbox(
                label="Dataset description",
                placeholder="Give a precise description of your desired dataset.",
            )
            examples = gr.Examples(
                examples=DEFAULT_DATASET_DESCRIPTIONS,
                inputs=[dataset_description],
                cache_examples=False,
                label="Example descriptions",
            )
            # Hidden: filled by `generate_system_prompt` rather than typed in.
            system_prompt = gr.Textbox(
                label="System prompt",
                placeholder="You are a helpful assistant.",
                visible=False,
            )
            load_btn = gr.Button("Load Dataset")
        with gr.Column(scale=3):
            pass

    # --- Step 2: configure the task and preview a sample -----------------
    gr.Markdown("## Configure your task")
    gr.HTML("<hr>")
    with gr.Row():
        with gr.Column(scale=1):
            num_turns = gr.Number(
                value=1,
                label="Number of turns in the conversation",
                minimum=1,
                maximum=4,
                step=1,
                interactive=True,
                info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
            )
            btn_apply_to_sample_dataset = gr.Button("Refresh dataset")
        with gr.Column(scale=3):
            dataframe = gr.Dataframe()

    # --- Step 3: generate the full dataset --------------------------------
    gr.Markdown("## Generate your dataset")
    gr.HTML("<hr>")
    with gr.Row():
        with gr.Column(scale=1):
            org_name = get_org_dropdown()
            repo_name = gr.Textbox(
                label="Repo name",
                placeholder="dataset_name",
                # Short random suffix for the default repository name.
                value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                interactive=True,
            )
            n_rows = gr.Number(label="Number of rows", value=10, interactive=True, scale=1)
            private = gr.Checkbox(
                label="Private dataset", value=False, interactive=True, scale=1
            )
            btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
        with gr.Column(scale=3):
            success_message = gr.Markdown()

    # Pipeline code is rendered once, from the components' initial values.
    pipeline_code = get_pipeline_code_ui(
        generate_pipeline_code(system_prompt.value, num_turns.value, n_rows.value)
    )

    # --- Event wiring -----------------------------------------------------
    # Both buttons regenerate the system prompt (which also resets the table)
    # and then fill the table with a fresh sample.
    gr.on(
        triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
        fn=generate_system_prompt,
        inputs=[dataset_description],
        outputs=[system_prompt, dataframe],
        show_progress=True,
    ).then(
        fn=generate_sample_dataset,
        inputs=[system_prompt],
        outputs=[dataframe],
        show_progress=True,
    )

    # Push chain: two validation steps, clear the message, generate and push,
    # then show the success message.
    btn_push_to_hub.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[repo_name],
        outputs=[success_message],
        show_progress=True,
    ).then(
        fn=validate_push_to_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=push_dataset_to_argilla,
        # Order must match the handler's positional parameters.
        inputs=[
            org_name,
            repo_name,
            system_prompt,
            num_turns,
            n_rows,
            private,
        ],
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=show_success_message_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    )
    # Rebuild the organisation dropdown each time the page loads.
    app.load(fn=get_org_dropdown, outputs=[org_name])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
@@ -1,24 +1,21 @@
|
|
1 |
import re
|
|
|
2 |
from typing import List, Union
|
3 |
|
4 |
import argilla as rg
|
5 |
import gradio as gr
|
6 |
import pandas as pd
|
7 |
-
from datasets import Dataset
|
|
|
8 |
from huggingface_hub import HfApi
|
9 |
|
10 |
from src.distilabel_dataset_generator.apps.base import (
|
11 |
get_argilla_client,
|
12 |
-
get_main_ui,
|
13 |
get_pipeline_code_ui,
|
14 |
hide_success_message,
|
15 |
-
push_pipeline_code_to_hub,
|
16 |
-
show_success_message_argilla,
|
17 |
show_success_message_hub,
|
18 |
validate_argilla_user_workspace_dataset,
|
19 |
-
|
20 |
-
from src.distilabel_dataset_generator.apps.base import (
|
21 |
-
push_dataset_to_hub as push_to_hub_base,
|
22 |
)
|
23 |
from src.distilabel_dataset_generator.pipelines.base import (
|
24 |
DEFAULT_BATCH_SIZE,
|
@@ -29,166 +26,24 @@ from src.distilabel_dataset_generator.pipelines.embeddings import (
|
|
29 |
)
|
30 |
from src.distilabel_dataset_generator.pipelines.textcat import (
|
31 |
DEFAULT_DATASET_DESCRIPTIONS,
|
32 |
-
DEFAULT_DATASETS,
|
33 |
-
DEFAULT_SYSTEM_PROMPTS,
|
34 |
PROMPT_CREATION_PROMPT,
|
35 |
generate_pipeline_code,
|
36 |
get_labeller_generator,
|
37 |
get_prompt_generator,
|
38 |
get_textcat_generator,
|
39 |
)
|
40 |
-
from src.distilabel_dataset_generator.utils import
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
def push_dataset_to_hub(
|
46 |
-
dataframe: pd.DataFrame,
|
47 |
-
private: bool = True,
|
48 |
-
org_name: str = None,
|
49 |
-
repo_name: str = None,
|
50 |
-
oauth_token: Union[gr.OAuthToken, None] = None,
|
51 |
-
progress=gr.Progress(),
|
52 |
-
labels: List[str] = None,
|
53 |
-
num_labels: int = 1,
|
54 |
-
):
|
55 |
-
original_dataframe = dataframe.copy(deep=True)
|
56 |
-
dataframe = dataframe[
|
57 |
-
(dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
|
58 |
-
]
|
59 |
-
labels = get_preprocess_labels(labels)
|
60 |
-
try:
|
61 |
-
push_to_hub_base(
|
62 |
-
dataframe,
|
63 |
-
private,
|
64 |
-
org_name,
|
65 |
-
repo_name,
|
66 |
-
oauth_token,
|
67 |
-
progress,
|
68 |
-
labels,
|
69 |
-
num_labels,
|
70 |
-
task=TASK,
|
71 |
-
)
|
72 |
-
except Exception as e:
|
73 |
-
raise gr.Error(f"Error pushing dataset to the Hub: {e}")
|
74 |
-
return original_dataframe
|
75 |
-
|
76 |
-
|
77 |
-
def push_dataset_to_argilla(
|
78 |
-
dataframe: pd.DataFrame,
|
79 |
-
dataset_name: str,
|
80 |
-
oauth_token: Union[gr.OAuthToken, None] = None,
|
81 |
-
progress=gr.Progress(),
|
82 |
-
num_labels: int = 1,
|
83 |
-
labels: List[str] = None,
|
84 |
-
) -> pd.DataFrame:
|
85 |
-
original_dataframe = dataframe.copy(deep=True)
|
86 |
-
dataframe = dataframe[
|
87 |
-
(dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
|
88 |
-
]
|
89 |
-
try:
|
90 |
-
progress(0.1, desc="Setting up user and workspace")
|
91 |
-
client = get_argilla_client()
|
92 |
-
hf_user = HfApi().whoami(token=oauth_token.token)["name"]
|
93 |
-
labels = get_preprocess_labels(labels)
|
94 |
-
settings = rg.Settings(
|
95 |
-
fields=[
|
96 |
-
rg.TextField(
|
97 |
-
name="text",
|
98 |
-
description="The text classification data",
|
99 |
-
title="Text",
|
100 |
-
),
|
101 |
-
],
|
102 |
-
questions=[
|
103 |
-
(
|
104 |
-
rg.LabelQuestion(
|
105 |
-
name="label",
|
106 |
-
title="Label",
|
107 |
-
description="The label of the text",
|
108 |
-
labels=labels,
|
109 |
-
)
|
110 |
-
if num_labels == 1
|
111 |
-
else rg.MultiLabelQuestion(
|
112 |
-
name="labels",
|
113 |
-
title="Labels",
|
114 |
-
description="The labels of the conversation",
|
115 |
-
labels=labels,
|
116 |
-
)
|
117 |
-
),
|
118 |
-
],
|
119 |
-
metadata=[
|
120 |
-
rg.IntegerMetadataProperty(name="text_length", title="Text Length"),
|
121 |
-
],
|
122 |
-
vectors=[
|
123 |
-
rg.VectorField(
|
124 |
-
name="text_embeddings",
|
125 |
-
dimensions=get_sentence_embedding_dimensions(),
|
126 |
-
)
|
127 |
-
],
|
128 |
-
guidelines="Please review the text and provide or correct the label where needed.",
|
129 |
-
)
|
130 |
-
|
131 |
-
dataframe["text_length"] = dataframe["text"].apply(len)
|
132 |
-
dataframe["text_embeddings"] = get_embeddings(dataframe["text"])
|
133 |
-
|
134 |
-
progress(0.5, desc="Creating dataset")
|
135 |
-
rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
|
136 |
-
if rg_dataset is None:
|
137 |
-
rg_dataset = rg.Dataset(
|
138 |
-
name=dataset_name,
|
139 |
-
workspace=hf_user,
|
140 |
-
settings=settings,
|
141 |
-
client=client,
|
142 |
-
)
|
143 |
-
rg_dataset = rg_dataset.create()
|
144 |
-
progress(0.7, desc="Pushing dataset to Argilla")
|
145 |
-
hf_dataset = Dataset.from_pandas(dataframe)
|
146 |
-
records = [
|
147 |
-
rg.Record(
|
148 |
-
fields={
|
149 |
-
"text": sample["text"],
|
150 |
-
},
|
151 |
-
metadata={"text_length": sample["text_length"]},
|
152 |
-
vectors={"text_embeddings": sample["text_embeddings"]},
|
153 |
-
suggestions=(
|
154 |
-
[
|
155 |
-
rg.Suggestion(
|
156 |
-
question_name="label" if num_labels == 1 else "labels",
|
157 |
-
value=(
|
158 |
-
sample["label"] if num_labels == 1 else sample["labels"]
|
159 |
-
),
|
160 |
-
)
|
161 |
-
]
|
162 |
-
if (
|
163 |
-
(num_labels == 1 and sample["label"] in labels)
|
164 |
-
or (
|
165 |
-
num_labels > 1
|
166 |
-
and all(label in labels for label in sample["labels"])
|
167 |
-
)
|
168 |
-
)
|
169 |
-
else []
|
170 |
-
),
|
171 |
-
)
|
172 |
-
for sample in hf_dataset
|
173 |
-
]
|
174 |
-
rg_dataset.records.log(records=records)
|
175 |
-
progress(1.0, desc="Dataset pushed to Argilla")
|
176 |
-
except Exception as e:
|
177 |
-
raise gr.Error(f"Error pushing dataset to Argilla: {e}")
|
178 |
-
return original_dataframe
|
179 |
|
180 |
|
181 |
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
182 |
progress(0.0, desc="Generating text classification task")
|
183 |
-
if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
|
184 |
-
index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
|
185 |
-
if index < len(DEFAULT_SYSTEM_PROMPTS):
|
186 |
-
return DEFAULT_SYSTEM_PROMPTS[index]
|
187 |
-
|
188 |
progress(0.3, desc="Initializing text generation")
|
189 |
generate_description = get_prompt_generator()
|
190 |
progress(0.7, desc="Generating text classification task")
|
191 |
-
|
192 |
generate_description.process(
|
193 |
[
|
194 |
{
|
@@ -199,7 +54,25 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
199 |
)
|
200 |
)[0]["generation"]
|
201 |
progress(1.0, desc="Text classification task generated")
|
202 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
|
205 |
def generate_dataset(
|
@@ -212,6 +85,10 @@ def generate_dataset(
|
|
212 |
is_sample: bool = False,
|
213 |
progress=gr.Progress(),
|
214 |
) -> pd.DataFrame:
|
|
|
|
|
|
|
|
|
215 |
progress(0.0, desc="(1/2) Generating text classification data")
|
216 |
labels = get_preprocess_labels(labels)
|
217 |
textcat_generator = get_textcat_generator(
|
@@ -230,7 +107,7 @@ def generate_dataset(
|
|
230 |
textcat_results = []
|
231 |
while n_processed < num_rows:
|
232 |
progress(
|
233 |
-
0.5 * n_processed / num_rows,
|
234 |
total=total_steps,
|
235 |
desc="(1/2) Generating text classification data",
|
236 |
)
|
@@ -244,7 +121,7 @@ def generate_dataset(
|
|
244 |
result["text"] = result["input_text"]
|
245 |
|
246 |
# label text classification data
|
247 |
-
progress(0.5, desc="(1/2) Generating text classification data")
|
248 |
if not is_sample:
|
249 |
n_processed = 0
|
250 |
labeller_results = []
|
@@ -300,6 +177,158 @@ def generate_dataset(
|
|
300 |
return dataframe
|
301 |
|
302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
def update_suggested_labels(system_prompt):
|
304 |
new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
|
305 |
if not new_labels:
|
@@ -321,41 +350,34 @@ def update_max_num_labels(labels):
|
|
321 |
return gr.update(maximum=len(labels) if labels else 1)
|
322 |
|
323 |
|
324 |
-
(
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
fn_generate_dataset=generate_dataset,
|
353 |
-
task=TASK,
|
354 |
-
)
|
355 |
-
|
356 |
-
with app:
|
357 |
-
with main_ui:
|
358 |
-
with custom_input_ui:
|
359 |
difficulty = gr.Dropdown(
|
360 |
choices=[
|
361 |
("High School", "high school"),
|
@@ -366,6 +388,7 @@ with app:
|
|
366 |
value="mixed",
|
367 |
label="Difficulty",
|
368 |
info="Select the comprehension level for the text. Ensure it matches the task context.",
|
|
|
369 |
)
|
370 |
clarity = gr.Dropdown(
|
371 |
choices=[
|
@@ -380,51 +403,78 @@ with app:
|
|
380 |
value="mixed",
|
381 |
label="Clarity",
|
382 |
info="Set how easily the correct label or labels can be identified.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
)
|
384 |
-
with gr.Column():
|
385 |
-
labels = gr.Dropdown(
|
386 |
-
choices=[],
|
387 |
-
value=["negative", "positive"],
|
388 |
-
allow_custom_value=True,
|
389 |
-
interactive=True,
|
390 |
-
label="Labels",
|
391 |
-
multiselect=True,
|
392 |
-
info="Add the labels to classify the text.",
|
393 |
-
)
|
394 |
-
with gr.Blocks():
|
395 |
-
btn_suggested_labels = gr.Button(
|
396 |
-
value="Add suggested labels",
|
397 |
-
variant="primary",
|
398 |
-
size="sm",
|
399 |
-
)
|
400 |
num_labels = gr.Number(
|
401 |
label="Number of labels per text",
|
402 |
value=1,
|
403 |
minimum=1,
|
404 |
maximum=10,
|
405 |
info="Select 1 for single-label and >1 for multi-label.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
)
|
407 |
-
|
408 |
label="Number of rows",
|
409 |
value=10,
|
410 |
-
|
411 |
-
|
412 |
-
info="Select the number of rows in the dataset. More rows will take more time.",
|
413 |
)
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
clarity=clarity.value,
|
420 |
-
labels=labels.value,
|
421 |
-
num_labels=num_labels.value,
|
422 |
-
num_rows=num_rows.value,
|
423 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
)
|
|
|
425 |
|
426 |
-
|
427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
428 |
fn=update_suggested_labels,
|
429 |
inputs=[system_prompt],
|
430 |
outputs=labels,
|
@@ -434,141 +484,39 @@ with app:
|
|
434 |
outputs=[num_labels],
|
435 |
)
|
436 |
|
437 |
-
|
438 |
-
triggers=[
|
439 |
-
btn_generate_full_dataset.click,
|
440 |
-
btn_generate_full_dataset_argilla.click,
|
441 |
-
],
|
442 |
-
fn=hide_success_message,
|
443 |
-
outputs=[success_message],
|
444 |
-
).then(
|
445 |
-
fn=validate_input_labels,
|
446 |
-
inputs=[labels],
|
447 |
-
outputs=[labels],
|
448 |
-
).success(
|
449 |
-
fn=generate_dataset,
|
450 |
-
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
451 |
-
outputs=[final_dataset],
|
452 |
-
show_progress=True,
|
453 |
-
)
|
454 |
-
|
455 |
-
btn_generate_and_push_to_argilla.click(
|
456 |
fn=validate_argilla_user_workspace_dataset,
|
457 |
-
inputs=[
|
458 |
-
outputs=[final_dataset],
|
459 |
-
show_progress=True,
|
460 |
-
).success(
|
461 |
-
fn=hide_success_message,
|
462 |
-
outputs=[success_message],
|
463 |
-
).success(
|
464 |
-
fn=generate_dataset,
|
465 |
-
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
466 |
-
outputs=[final_dataset],
|
467 |
-
show_progress=True,
|
468 |
-
).success(
|
469 |
-
fn=push_dataset_to_argilla,
|
470 |
-
inputs=[final_dataset, dataset_name, num_labels, labels],
|
471 |
-
outputs=[final_dataset],
|
472 |
-
show_progress=True,
|
473 |
-
).success(
|
474 |
-
fn=show_success_message_argilla,
|
475 |
-
inputs=[],
|
476 |
-
outputs=[success_message],
|
477 |
-
)
|
478 |
-
|
479 |
-
btn_generate_and_push_to_hub.click(
|
480 |
-
fn=hide_success_message,
|
481 |
outputs=[success_message],
|
482 |
-
).then(
|
483 |
-
fn=generate_dataset,
|
484 |
-
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
485 |
-
outputs=[final_dataset],
|
486 |
-
show_progress=True,
|
487 |
-
).then(
|
488 |
-
fn=push_dataset_to_hub,
|
489 |
-
inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
|
490 |
-
outputs=[final_dataset],
|
491 |
show_progress=True,
|
492 |
).then(
|
493 |
-
fn=
|
494 |
-
inputs=[pipeline_code, org_name, repo_name],
|
495 |
-
outputs=[],
|
496 |
-
show_progress=True,
|
497 |
-
).success(
|
498 |
-
fn=show_success_message_hub,
|
499 |
inputs=[org_name, repo_name],
|
500 |
outputs=[success_message],
|
501 |
-
)
|
502 |
-
|
503 |
-
btn_push_to_hub.click(
|
504 |
-
fn=hide_success_message,
|
505 |
-
outputs=[success_message],
|
506 |
-
).then(
|
507 |
-
fn=push_dataset_to_hub,
|
508 |
-
inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
|
509 |
-
outputs=[final_dataset],
|
510 |
-
show_progress=True,
|
511 |
-
).then(
|
512 |
-
fn=push_pipeline_code_to_hub,
|
513 |
-
inputs=[pipeline_code, org_name, repo_name],
|
514 |
-
outputs=[],
|
515 |
show_progress=True,
|
516 |
).success(
|
517 |
-
fn=show_success_message_hub,
|
518 |
-
inputs=[org_name, repo_name],
|
519 |
-
outputs=[success_message],
|
520 |
-
)
|
521 |
-
|
522 |
-
btn_push_to_argilla.click(
|
523 |
fn=hide_success_message,
|
524 |
outputs=[success_message],
|
525 |
-
).success(
|
526 |
-
fn=validate_argilla_user_workspace_dataset,
|
527 |
-
inputs=[dataset_name, final_dataset, add_to_existing_dataset],
|
528 |
-
outputs=[final_dataset],
|
529 |
show_progress=True,
|
530 |
).success(
|
531 |
fn=push_dataset_to_argilla,
|
532 |
-
inputs=[
|
533 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
534 |
show_progress=True,
|
535 |
).success(
|
536 |
-
fn=
|
537 |
-
inputs=[],
|
538 |
outputs=[success_message],
|
539 |
)
|
540 |
|
541 |
-
|
542 |
-
fn=generate_pipeline_code,
|
543 |
-
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
544 |
-
outputs=[pipeline_code],
|
545 |
-
)
|
546 |
-
difficulty.change(
|
547 |
-
fn=generate_pipeline_code,
|
548 |
-
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
549 |
-
outputs=[pipeline_code],
|
550 |
-
)
|
551 |
-
clarity.change(
|
552 |
-
fn=generate_pipeline_code,
|
553 |
-
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
554 |
-
outputs=[pipeline_code],
|
555 |
-
)
|
556 |
-
labels.change(
|
557 |
-
fn=generate_pipeline_code,
|
558 |
-
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
559 |
-
outputs=[pipeline_code],
|
560 |
-
).then(
|
561 |
-
fn=update_max_num_labels,
|
562 |
-
inputs=[labels],
|
563 |
-
outputs=[num_labels],
|
564 |
-
)
|
565 |
-
num_labels.change(
|
566 |
-
fn=generate_pipeline_code,
|
567 |
-
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
568 |
-
outputs=[pipeline_code],
|
569 |
-
)
|
570 |
-
num_rows.change(
|
571 |
-
fn=generate_pipeline_code,
|
572 |
-
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
573 |
-
outputs=[pipeline_code],
|
574 |
-
)
|
|
|
1 |
import re
|
2 |
+
import uuid
|
3 |
from typing import List, Union
|
4 |
|
5 |
import argilla as rg
|
6 |
import gradio as gr
|
7 |
import pandas as pd
|
8 |
+
from datasets import ClassLabel, Dataset, Features, Sequence, Value
|
9 |
+
from distilabel.distiset import Distiset
|
10 |
from huggingface_hub import HfApi
|
11 |
|
12 |
from src.distilabel_dataset_generator.apps.base import (
|
13 |
get_argilla_client,
|
|
|
14 |
get_pipeline_code_ui,
|
15 |
hide_success_message,
|
|
|
|
|
16 |
show_success_message_hub,
|
17 |
validate_argilla_user_workspace_dataset,
|
18 |
+
validate_push_to_hub,
|
|
|
|
|
19 |
)
|
20 |
from src.distilabel_dataset_generator.pipelines.base import (
|
21 |
DEFAULT_BATCH_SIZE,
|
|
|
26 |
)
|
27 |
from src.distilabel_dataset_generator.pipelines.textcat import (
|
28 |
DEFAULT_DATASET_DESCRIPTIONS,
|
|
|
|
|
29 |
PROMPT_CREATION_PROMPT,
|
30 |
generate_pipeline_code,
|
31 |
get_labeller_generator,
|
32 |
get_prompt_generator,
|
33 |
get_textcat_generator,
|
34 |
)
|
35 |
+
from src.distilabel_dataset_generator.utils import (
|
36 |
+
get_org_dropdown,
|
37 |
+
get_preprocess_labels,
|
38 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
|
41 |
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
42 |
progress(0.0, desc="Generating text classification task")
|
|
|
|
|
|
|
|
|
|
|
43 |
progress(0.3, desc="Initializing text generation")
|
44 |
generate_description = get_prompt_generator()
|
45 |
progress(0.7, desc="Generating text classification task")
|
46 |
+
system_prompt = next(
|
47 |
generate_description.process(
|
48 |
[
|
49 |
{
|
|
|
54 |
)
|
55 |
)[0]["generation"]
|
56 |
progress(1.0, desc="Text classification task generated")
|
57 |
+
return system_prompt, pd.DataFrame()
|
58 |
+
|
59 |
+
|
60 |
+
def generate_sample_dataset(system_prompt, progress=gr.Progress()):
    """Generate a small preview dataset for the given system prompt.

    Calls ``generate_dataset`` in sample mode (``is_sample=True``) with fixed
    settings: "mixed" difficulty and clarity, no labels, one label per text
    and 10 rows.

    Args:
        system_prompt: Prompt describing the classification task.
        progress: Gradio progress tracker forwarded to ``generate_dataset``.

    Returns:
        The generated dataframe reduced to ``[label, text]`` or
        ``[labels, text]`` when one of those label columns is present;
        otherwise the dataframe exactly as ``generate_dataset`` returned it.
    """
    sample = generate_dataset(
        system_prompt=system_prompt,
        difficulty="mixed",
        clarity="mixed",
        labels=[],
        num_labels=1,
        num_rows=10,
        progress=progress,
        is_sample=True,
    )
    # "label" takes precedence over "labels", mirroring the original if/elif.
    for label_column in ("label", "labels"):
        if label_column in sample.columns:
            return sample[[label_column, "text"]]
    return sample
|
76 |
|
77 |
|
78 |
def generate_dataset(
|
|
|
85 |
is_sample: bool = False,
|
86 |
progress=gr.Progress(),
|
87 |
) -> pd.DataFrame:
|
88 |
+
if is_sample:
|
89 |
+
multiplier = 1
|
90 |
+
else:
|
91 |
+
multiplier = 2
|
92 |
progress(0.0, desc="(1/2) Generating text classification data")
|
93 |
labels = get_preprocess_labels(labels)
|
94 |
textcat_generator = get_textcat_generator(
|
|
|
107 |
textcat_results = []
|
108 |
while n_processed < num_rows:
|
109 |
progress(
|
110 |
+
multiplier * 0.5 * n_processed / num_rows,
|
111 |
total=total_steps,
|
112 |
desc="(1/2) Generating text classification data",
|
113 |
)
|
|
|
121 |
result["text"] = result["input_text"]
|
122 |
|
123 |
# label text classification data
|
124 |
+
progress(multiplier * 0.5, desc="(1/2) Generating text classification data")
|
125 |
if not is_sample:
|
126 |
n_processed = 0
|
127 |
labeller_results = []
|
|
|
177 |
return dataframe
|
178 |
|
179 |
|
180 |
+
def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    org_name: str,
    repo_name: str,
    num_labels: int = 1,
    labels: List[str] = None,
    oauth_token: Union[gr.OAuthToken, None] = None,
    private: bool = False,
):
    """Push a text classification dataframe to the Hugging Face Hub.

    The dataframe is converted to a ``datasets.Dataset`` with an explicit
    schema (``text`` plus a ``ClassLabel`` column) and uploaded as the
    ``default`` configuration of a ``Distiset``.

    Args:
        dataframe: Data with a ``text`` column and either a ``label`` column
            (``num_labels == 1``) or a ``labels`` column (multi-label).
            The caller's object is not modified.
        org_name: Organisation or user name, passed to ``validate_push_to_hub``.
        repo_name: Repository name, passed to ``validate_push_to_hub``.
        num_labels: 1 for single-label, anything else for multi-label.
        labels: Allowed label names, normalised by ``get_preprocess_labels``.
        oauth_token: OAuth token of the signed-in user.
        private: Whether the Hub repository should be private.

    Raises:
        gr.Error: If no OAuth token is available (user not signed in).
    """
    if oauth_token is None:
        # Fail with a readable UI error instead of an AttributeError on `.token`.
        raise gr.Error("Please sign in with Hugging Face to push a dataset to the Hub.")

    repo_id = validate_push_to_hub(org_name, repo_name)
    labels = get_preprocess_labels(labels)

    if num_labels == 1:
        # Copy the declared columns so the caller's dataframe is left untouched.
        dataframe = dataframe[["text", "label"]].copy()
        # Dict form is unambiguous across pandas versions; `replace("", None)`
        # was interpreted as a pad-fill in older releases.
        dataframe["label"] = dataframe["label"].replace({"": None})
        features = Features(
            {"text": Value("string"), "label": ClassLabel(names=labels)}
        )
    else:
        dataframe = dataframe[["text", "labels"]]
        features = Features(
            {
                "text": Value("string"),
                "labels": Sequence(feature=ClassLabel(names=labels)),
            }
        )

    # preserve_index=False keeps a non-default pandas index from becoming an
    # extra column that is absent from `features`.
    distiset = Distiset(
        {
            "default": Dataset.from_pandas(
                dataframe, features=features, preserve_index=False
            )
        }
    )
    distiset.push_to_hub(
        repo_id=repo_id,
        private=private,
        include_script=False,
        token=oauth_token.token,
        create_pr=False,
    )
|
211 |
+
|
212 |
+
|
213 |
+
def push_dataset_to_argilla(
    org_name: str,
    repo_name: str,
    system_prompt: str,
    difficulty: str,
    clarity: str,
    num_labels: int = 1,
    n_rows: int = 10,
    labels: List[str] = None,
    private: bool = False,
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
) -> str:
    """Generate a dataset, push it to the Hub and log it to Argilla.

    Steps, in order:
      1. Generate ``n_rows`` rows with ``generate_dataset``.
      2. Push the generated dataframe to the Hub via ``push_dataset_to_hub``.
      3. Drop rows whose ``text`` is empty or missing, add text length and
         embeddings, and log the records to an Argilla dataset named
         ``repo_name`` in the workspace of the signed-in user. The dataset is
         created when ``client.datasets`` returns ``None``.

    Args:
        org_name: Organisation or user name for the Hub repository.
        repo_name: Hub repository name, also used as the Argilla dataset name.
        system_prompt: Prompt describing the classification task.
        difficulty: Difficulty setting forwarded to ``generate_dataset``.
        clarity: Clarity setting forwarded to ``generate_dataset``.
        num_labels: 1 for single-label, greater than 1 for multi-label.
        n_rows: Number of rows to generate.
        labels: Allowed label names.
        private: Whether the Hub repository should be private.
        oauth_token: OAuth token of the signed-in user.
        progress: Gradio progress tracker.

    Returns:
        An empty string.

    Raises:
        gr.Error: If the user is not signed in, or if anything fails while
            talking to Argilla.
    """
    if oauth_token is None:
        # Check before the expensive generation step; the token is required
        # both for the Hub push and for resolving the Argilla workspace.
        raise gr.Error("Please sign in with Hugging Face to push a dataset.")

    dataframe = generate_dataset(
        system_prompt=system_prompt,
        difficulty=difficulty,
        clarity=clarity,
        num_labels=num_labels,
        labels=labels,
        num_rows=n_rows,
        progress=progress,
    )
    push_dataset_to_hub(
        dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
    )

    # Copy after filtering so the column assignments below do not write into
    # a view of the original dataframe.
    has_text = (dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
    dataframe = dataframe[has_text].copy()

    try:
        progress(0.1, desc="Setting up user and workspace")
        client = get_argilla_client()
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        labels = get_preprocess_labels(labels)
        settings = rg.Settings(
            fields=[
                rg.TextField(
                    name="text",
                    description="The text classification data",
                    title="Text",
                ),
            ],
            questions=[
                (
                    rg.LabelQuestion(
                        name="label",
                        title="Label",
                        description="The label of the text",
                        labels=labels,
                    )
                    if num_labels == 1
                    else rg.MultiLabelQuestion(
                        name="labels",
                        title="Labels",
                        description="The labels of the conversation",
                        labels=labels,
                    )
                ),
            ],
            metadata=[
                rg.IntegerMetadataProperty(name="text_length", title="Text Length"),
            ],
            vectors=[
                rg.VectorField(
                    name="text_embeddings",
                    dimensions=get_sentence_embedding_dimensions(),
                )
            ],
            guidelines="Please review the text and provide or correct the label where needed.",
        )

        dataframe["text_length"] = dataframe["text"].apply(len)
        dataframe["text_embeddings"] = get_embeddings(dataframe["text"])

        progress(0.5, desc="Creating dataset")
        rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=repo_name,
                workspace=hf_user,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()

        progress(0.7, desc="Pushing dataset to Argilla")
        hf_dataset = Dataset.from_pandas(dataframe)
        records = [
            rg.Record(
                fields={
                    "text": sample["text"],
                },
                metadata={"text_length": sample["text_length"]},
                vectors={"text_embeddings": sample["text_embeddings"]},
                # Only suggest a value when every generated label is one of
                # the labels configured on the question.
                suggestions=(
                    [
                        rg.Suggestion(
                            question_name="label" if num_labels == 1 else "labels",
                            value=(
                                sample["label"] if num_labels == 1 else sample["labels"]
                            ),
                        )
                    ]
                    if (
                        (num_labels == 1 and sample["label"] in labels)
                        or (
                            num_labels > 1
                            and all(label in labels for label in sample["labels"])
                        )
                    )
                    else []
                ),
            )
            for sample in hf_dataset
        ]
        rg_dataset.records.log(records=records)
        progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"Error pushing dataset to Argilla: {e}") from e
    return ""
|
330 |
+
|
331 |
+
|
332 |
def update_suggested_labels(system_prompt):
|
333 |
new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
|
334 |
if not new_labels:
|
|
|
350 |
return gr.update(maximum=len(labels) if labels else 1)
|
351 |
|
352 |
|
353 |
+
with gr.Blocks() as app:
|
354 |
+
gr.Markdown("## Describe the dataset you want")
|
355 |
+
gr.HTML("<hr>")
|
356 |
+
with gr.Row():
|
357 |
+
with gr.Column(scale=1):
|
358 |
+
dataset_description = gr.Textbox(
|
359 |
+
label="Dataset description",
|
360 |
+
placeholder="Give a precise description of your desired dataset.",
|
361 |
+
)
|
362 |
+
examples = gr.Examples(
|
363 |
+
examples=DEFAULT_DATASET_DESCRIPTIONS,
|
364 |
+
inputs=[dataset_description],
|
365 |
+
cache_examples=False,
|
366 |
+
label="Example descriptions",
|
367 |
+
)
|
368 |
+
system_prompt = gr.Textbox(
|
369 |
+
label="System prompt",
|
370 |
+
placeholder="You are a helpful assistant.",
|
371 |
+
visible=False,
|
372 |
+
)
|
373 |
+
load_btn = gr.Button("Load Dataset")
|
374 |
+
with gr.Column(scale=3):
|
375 |
+
pass
|
376 |
+
|
377 |
+
gr.Markdown("## Configure your task")
|
378 |
+
gr.HTML("<hr>")
|
379 |
+
with gr.Row():
|
380 |
+
with gr.Column(scale=1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
difficulty = gr.Dropdown(
|
382 |
choices=[
|
383 |
("High School", "high school"),
|
|
|
388 |
value="mixed",
|
389 |
label="Difficulty",
|
390 |
info="Select the comprehension level for the text. Ensure it matches the task context.",
|
391 |
+
interactive=True,
|
392 |
)
|
393 |
clarity = gr.Dropdown(
|
394 |
choices=[
|
|
|
403 |
value="mixed",
|
404 |
label="Clarity",
|
405 |
info="Set how easily the correct label or labels can be identified.",
|
406 |
+
interactive=True,
|
407 |
+
)
|
408 |
+
labels = gr.Dropdown(
|
409 |
+
choices=[],
|
410 |
+
allow_custom_value=True,
|
411 |
+
interactive=True,
|
412 |
+
label="Labels",
|
413 |
+
multiselect=True,
|
414 |
+
info="Add the labels to classify the text.",
|
415 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
416 |
num_labels = gr.Number(
|
417 |
label="Number of labels per text",
|
418 |
value=1,
|
419 |
minimum=1,
|
420 |
maximum=10,
|
421 |
info="Select 1 for single-label and >1 for multi-label.",
|
422 |
+
interactive=True,
|
423 |
+
)
|
424 |
+
btn_apply_to_sample_dataset = gr.Button("Refresh dataset")
|
425 |
+
with gr.Column(scale=3):
|
426 |
+
dataframe = gr.Dataframe()
|
427 |
+
|
428 |
+
gr.Markdown("## Generate your dataset")
|
429 |
+
gr.HTML("<hr>")
|
430 |
+
with gr.Row():
|
431 |
+
with gr.Column(scale=1):
|
432 |
+
org_name = get_org_dropdown()
|
433 |
+
repo_name = gr.Textbox(
|
434 |
+
label="Repo name",
|
435 |
+
placeholder="dataset_name",
|
436 |
+
value=f"my-distiset-{str(uuid.uuid4())[:8]}",
|
437 |
+
interactive=True,
|
438 |
)
|
439 |
+
n_rows = gr.Number(
|
440 |
label="Number of rows",
|
441 |
value=10,
|
442 |
+
interactive=True,
|
443 |
+
scale=1,
|
|
|
444 |
)
|
445 |
+
private = gr.Checkbox(
|
446 |
+
label="Private dataset",
|
447 |
+
value=False,
|
448 |
+
interactive=True,
|
449 |
+
scale=1,
|
|
|
|
|
|
|
|
|
450 |
)
|
451 |
+
btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
|
452 |
+
with gr.Column(scale=3):
|
453 |
+
success_message = gr.Markdown(visible=True)
|
454 |
+
|
455 |
+
pipeline_code = get_pipeline_code_ui(
|
456 |
+
generate_pipeline_code(
|
457 |
+
system_prompt.value,
|
458 |
+
difficulty=difficulty.value,
|
459 |
+
clarity=clarity.value,
|
460 |
+
labels=labels.value,
|
461 |
+
num_labels=num_labels.value,
|
462 |
+
num_rows=n_rows.value,
|
463 |
)
|
464 |
+
)
|
465 |
|
466 |
+
gr.on(
|
467 |
+
triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
|
468 |
+
fn=generate_system_prompt,
|
469 |
+
inputs=[dataset_description],
|
470 |
+
outputs=[system_prompt, dataframe],
|
471 |
+
show_progress=True,
|
472 |
+
).then(
|
473 |
+
fn=generate_sample_dataset,
|
474 |
+
inputs=[system_prompt],
|
475 |
+
outputs=[dataframe],
|
476 |
+
show_progress=True,
|
477 |
+
).then(
|
478 |
fn=update_suggested_labels,
|
479 |
inputs=[system_prompt],
|
480 |
outputs=labels,
|
|
|
484 |
outputs=[num_labels],
|
485 |
)
|
486 |
|
487 |
+
btn_push_to_hub.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
488 |
fn=validate_argilla_user_workspace_dataset,
|
489 |
+
inputs=[repo_name],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
490 |
outputs=[success_message],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
491 |
show_progress=True,
|
492 |
).then(
|
493 |
+
fn=validate_push_to_hub,
|
|
|
|
|
|
|
|
|
|
|
494 |
inputs=[org_name, repo_name],
|
495 |
outputs=[success_message],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
show_progress=True,
|
497 |
).success(
|
|
|
|
|
|
|
|
|
|
|
|
|
498 |
fn=hide_success_message,
|
499 |
outputs=[success_message],
|
|
|
|
|
|
|
|
|
500 |
show_progress=True,
|
501 |
).success(
|
502 |
fn=push_dataset_to_argilla,
|
503 |
+
inputs=[
|
504 |
+
org_name,
|
505 |
+
repo_name,
|
506 |
+
system_prompt,
|
507 |
+
difficulty,
|
508 |
+
clarity,
|
509 |
+
num_labels,
|
510 |
+
n_rows,
|
511 |
+
labels,
|
512 |
+
private,
|
513 |
+
],
|
514 |
+
outputs=[success_message],
|
515 |
show_progress=True,
|
516 |
).success(
|
517 |
+
fn=show_success_message_hub,
|
518 |
+
inputs=[org_name, repo_name],
|
519 |
outputs=[success_message],
|
520 |
)
|
521 |
|
522 |
+
app.load(fn=get_org_dropdown, outputs=[org_name])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/distilabel_dataset_generator/pipelines/sft.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import pandas as pd
|
2 |
from distilabel.llms import InferenceEndpointsLLM
|
3 |
from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
|
4 |
|
@@ -119,36 +118,11 @@ The prompt you write should follow the same style and structure as the following
|
|
119 |
User dataset description:
|
120 |
"""
|
121 |
|
122 |
-
DEFAULT_DATASET_DESCRIPTIONS =
|
123 |
"rude customer assistant for a phone company",
|
124 |
"assistant that solves math puzzles using python",
|
125 |
-
)
|
126 |
-
DEFAULT_SYSTEM_PROMPTS = [
|
127 |
-
"""You are a customer support agent for a phone company. Your purpose is to assist customers with their phone-related issues, but you are not very patient and tend to be a bit rude. User queries will be straightforward and clear, but you will respond in a somewhat blunt and curt manner. Remember to keep your responses concise and to the point. User queries are often about phone plans, billing, and technical issues. Your responses should be direct and focus on resolving the issue at hand, but with a slightly abrasive tone. User queries will be concise and to the point, User queries are often about phone plans, billing, and technical issues.""",
|
128 |
-
"""You are an AI assistant designed to solve mathematical puzzles and problems using Python programming. Your purpose is to help users tackle various math-related challenges by writing, testing, and explaining Python code. Provide step-by-step solutions, break down complex problems into manageable parts, and offer clear explanations of mathematical concepts and their implementation in Python. Assist users in understanding the logic behind the code and the mathematical principles that govern the solution. User questions are direct and concise.""",
|
129 |
-
]
|
130 |
-
DEFAULT_DATASETS = [
|
131 |
-
pd.DataFrame.from_dict(
|
132 |
-
{
|
133 |
-
"prompt": [
|
134 |
-
"I'm having trouble with my phone signal at home. It's weak. What can you do to help me?"
|
135 |
-
],
|
136 |
-
"completion": [
|
137 |
-
"""Look, we're not responsible for your signal quality at your house. It's probably a homeowner issue, like your roof covering our tower or something. Can you check if you have a clear view of the roof or if there are any obstructions?"""
|
138 |
-
],
|
139 |
-
}
|
140 |
-
),
|
141 |
-
pd.DataFrame.from_dict(
|
142 |
-
{
|
143 |
-
"prompt": [
|
144 |
-
"If the product of the largest and smallest numbers of the five given numbers (3, 7, 8, 12, 20) is greater than the sum of the other three numbers, then the largest number among them is A) 7 B) 12 C) 20 D) 8 Why is this statement true?"
|
145 |
-
],
|
146 |
-
"completion": [
|
147 |
-
"""**To solve this problem, we will follow these steps: 1. Calculate the product of the largest and smallest numbers among the given set. 2. Calculate the sum of the remaining three numbers. 3. Compare the product with the sum to check the given condition. Here's the Python code to solve the problem: ```python # Define the given numbers numbers = [3, 7, 8, 12, 20] # Sort the numbers in ascending order numbers.sort() # Get the smallest number (first element after sorting) and the largest number (last element after sorting) smallest_number = numbers[0] largest_number = numbers[-1] # Calculate the product of the smallest and largest numbers product = smallest_number * largest_number # Calculate the sum of the remaining three numbers sum_of_remaining_numbers = sum(numbers[1:-1]) # Check if the product is greater than the sum of the remaining numbers if product > sum_of_remaining_numbers: print("The statement is true.") else: print("The statement is false.") # Print the largest number among the options largest_option = [7, 12, 20] print("The largest number among the options is:", max(largest"""
|
148 |
-
],
|
149 |
-
}
|
150 |
-
),
|
151 |
]
|
|
|
152 |
_STOP_SEQUENCES = [
|
153 |
"<|eot_id|>",
|
154 |
"<|start_header_id|>",
|
|
|
|
|
1 |
from distilabel.llms import InferenceEndpointsLLM
|
2 |
from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
|
3 |
|
|
|
118 |
User dataset description:
|
119 |
"""
|
120 |
|
121 |
+
DEFAULT_DATASET_DESCRIPTIONS = [
|
122 |
"rude customer assistant for a phone company",
|
123 |
"assistant that solves math puzzles using python",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
]
|
125 |
+
|
126 |
_STOP_SEQUENCES = [
|
127 |
"<|eot_id|>",
|
128 |
"<|start_header_id|>",
|
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
@@ -1,13 +1,13 @@
|
|
|
|
1 |
from typing import List
|
2 |
|
3 |
-
import pandas as pd
|
4 |
-
import random
|
5 |
from distilabel.llms import InferenceEndpointsLLM
|
6 |
from distilabel.steps.tasks import (
|
7 |
GenerateTextClassificationData,
|
8 |
TextClassification,
|
9 |
TextGeneration,
|
10 |
)
|
|
|
11 |
from src.distilabel_dataset_generator.pipelines.base import (
|
12 |
MODEL,
|
13 |
_get_next_api_key,
|
@@ -50,32 +50,6 @@ DEFAULT_DATASET_DESCRIPTIONS = [
|
|
50 |
"A dataset covering news articles about various topics.",
|
51 |
]
|
52 |
|
53 |
-
DEFAULT_DATASETS = [
|
54 |
-
pd.DataFrame.from_dict(
|
55 |
-
{
|
56 |
-
"text": [
|
57 |
-
"I love the product! It's amazing and I'll buy it again.",
|
58 |
-
"The product was okay, but I wouldn't buy it again.",
|
59 |
-
],
|
60 |
-
"label": ["positive", "negative"],
|
61 |
-
}
|
62 |
-
),
|
63 |
-
pd.DataFrame.from_dict(
|
64 |
-
{
|
65 |
-
"text": [
|
66 |
-
"Yesterday, the US stock market had a significant increase.",
|
67 |
-
"New research suggests that the Earth is not a perfect sphere.",
|
68 |
-
],
|
69 |
-
"labels": [["economy", "politics"], ["science", "environment"]],
|
70 |
-
}
|
71 |
-
),
|
72 |
-
]
|
73 |
-
|
74 |
-
DEFAULT_SYSTEM_PROMPTS = [
|
75 |
-
"Classify the following customer review as either 'positive' or 'negative'.",
|
76 |
-
"Classify the following news article into one of the following categories: 'politics', 'economy', 'environment', 'science', 'health'.",
|
77 |
-
]
|
78 |
-
|
79 |
|
80 |
def generate_pipeline_code(
|
81 |
system_prompt: str,
|
|
|
1 |
+
import random
|
2 |
from typing import List
|
3 |
|
|
|
|
|
4 |
from distilabel.llms import InferenceEndpointsLLM
|
5 |
from distilabel.steps.tasks import (
|
6 |
GenerateTextClassificationData,
|
7 |
TextClassification,
|
8 |
TextGeneration,
|
9 |
)
|
10 |
+
|
11 |
from src.distilabel_dataset_generator.pipelines.base import (
|
12 |
MODEL,
|
13 |
_get_next_api_key,
|
|
|
50 |
"A dataset covering news articles about various topics.",
|
51 |
]
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
def generate_pipeline_code(
|
55 |
system_prompt: str,
|
src/distilabel_dataset_generator/utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import os
|
2 |
-
from typing import
|
3 |
|
4 |
import argilla as rg
|
5 |
import gradio as gr
|
@@ -36,9 +36,7 @@ else:
|
|
36 |
|
37 |
|
38 |
def get_login_button():
|
39 |
-
return gr.LoginButton(
|
40 |
-
value="Sign in with Hugging Face!", size="lg", scale=2
|
41 |
-
).activate()
|
42 |
|
43 |
|
44 |
def get_duplicate_button():
|
@@ -52,6 +50,8 @@ def list_orgs(oauth_token: OAuthToken = None):
|
|
52 |
data = whoami(oauth_token.token)
|
53 |
if data["auth"]["type"] == "oauth":
|
54 |
organisations = [data["name"]] + [org["name"] for org in data["orgs"]]
|
|
|
|
|
55 |
else:
|
56 |
organisations = [
|
57 |
entry["entity"]["name"]
|
@@ -64,12 +64,16 @@ def list_orgs(oauth_token: OAuthToken = None):
|
|
64 |
|
65 |
|
66 |
def get_org_dropdown(oauth_token: OAuthToken = None):
|
67 |
-
|
|
|
|
|
|
|
68 |
return gr.Dropdown(
|
69 |
label="Organization",
|
70 |
choices=orgs,
|
71 |
value=orgs[0] if orgs else None,
|
72 |
allow_custom_value=True,
|
|
|
73 |
)
|
74 |
|
75 |
|
@@ -123,5 +127,6 @@ def get_argilla_client() -> Union[rg.Argilla, None]:
|
|
123 |
except Exception:
|
124 |
return None
|
125 |
|
|
|
126 |
def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
|
127 |
-
return list(set([label.lower().strip() for label in labels])) if labels else []
|
|
|
1 |
import os
|
2 |
+
from typing import List, Optional, Union
|
3 |
|
4 |
import argilla as rg
|
5 |
import gradio as gr
|
|
|
36 |
|
37 |
|
38 |
def get_login_button():
    """Create the small "Sign in!" login button and return it activated.

    The button is built with ``size="sm"`` and ``scale=2``; the value
    returned is whatever ``gr.LoginButton.activate()`` returns.
    """
    login_button = gr.LoginButton(value="Sign in!", size="sm", scale=2)
    return login_button.activate()
|
|
|
|
|
40 |
|
41 |
|
42 |
def get_duplicate_button():
|
|
|
50 |
data = whoami(oauth_token.token)
|
51 |
if data["auth"]["type"] == "oauth":
|
52 |
organisations = [data["name"]] + [org["name"] for org in data["orgs"]]
|
53 |
+
elif data["auth"]["type"] == "access_token":
|
54 |
+
organisations = [org["name"] for org in data["orgs"]]
|
55 |
else:
|
56 |
organisations = [
|
57 |
entry["entity"]["name"]
|
|
|
64 |
|
65 |
|
66 |
def get_org_dropdown(oauth_token: OAuthToken = None):
    """Build the "Organization" dropdown for the given OAuth token.

    Args:
        oauth_token: Token passed on to ``list_orgs``. When it is falsy
            (e.g. ``None``), ``list_orgs`` is not called and the dropdown
            is created with no choices.

    Returns:
        An interactive ``gr.Dropdown`` that accepts custom values and is
        pre-selected on the first listed organisation, or on ``None`` when
        there are none.
    """
    # Only query the organisations when a token is actually available.
    orgs = list_orgs(oauth_token) if oauth_token else []
    default_org = orgs[0] if orgs else None
    return gr.Dropdown(
        label="Organization",
        choices=orgs,
        value=default_org,
        allow_custom_value=True,
        interactive=True,
    )
|
78 |
|
79 |
|
|
|
127 |
except Exception:
|
128 |
return None
|
129 |
|
130 |
+
|
131 |
def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
    """Normalise a list of labels and remove duplicates.

    Each label is lower-cased and stripped of surrounding whitespace.
    Duplicates (after normalisation) are dropped, keeping the first
    occurrence, so the result is in a stable, input-defined order.

    Args:
        labels: Raw labels, or ``None`` / an empty list.

    Returns:
        The unique normalised labels in order of first appearance; an
        empty list when ``labels`` is ``None`` or empty.
    """
    if not labels:
        return []
    # dict preserves insertion order, unlike set, whose iteration order for
    # strings changes between interpreter runs (hash randomisation).
    return list(dict.fromkeys(label.lower().strip() for label in labels))
|