Spaces:
Runtime error
Runtime error
from datasets import get_dataset_config_names, get_dataset_split_names | |
from distilabel.steps.tasks import ( | |
GenerateSentencePair, | |
TextGeneration, | |
) | |
from synthetic_dataset_generator.constants import MAX_NUM_TOKENS | |
from synthetic_dataset_generator.pipelines.base import _get_llm, _get_llm_class | |
# Sample dataset descriptions; the first two few-shot examples in
# PROMPT_CREATION_PROMPT use exactly this wording.
DEFAULT_DATASET_DESCRIPTIONS = list(
    (
        "A dataset to retrieve information from legal documents.",
        "A dataset to search for economical techniques.",
    )
)
# System prompt used by get_prompt_generator: few-shot "Description:" /
# "Output:" pairs showing how to expand a short dataset description into a
# detailed retrieval-task prompt. It ends with an open "Description:" line.
PROMPT_CREATION_PROMPT = "\n".join(
    [
        "",
        "You are an AI assistant specialized in designing retrieval-augmented generation (RAG) tasks for dataset generation.",
        "Your task is to generate a well-structured and descriptive prompt based on the provided dataset description. Respond with only the generated prompt and nothing else.",
        "The prompt should closely follow the style and structure of the example prompts below. Ensure that you include all relevant details from the dataset description.",
        "Description: A dataset to retrieve information from legal documents.",
        "Output: A dataset to retrieve information from a collection of legal documents related to the US law system and the status of contracts.",
        "Description: A dataset to search for economical techniques.",
        "Output: A dataset to search for economical techniques and strategies for the European market and the financial sector.",
        "Description: A dataset covering FAQ questions for a tech company called Argilla that sells technology datasets within the open-source Natural Language Processing space.",
        "Output: A dataset covering FAQ questions for a tech company called Argilla that sells technology datasets within the open-source Natural Language Processing space.",
        "Description:",
        "",
    ]
)
# System prompt for the chunk-generation task (see get_chunks_generator):
# asks for concise, task-relevant text chunks that do not restate the task.
SYSTEM_PROMPT_CHUNKS = """
You are a helpful and knowledgeable AI assistant. Your task is to generate concise and informative text chunks relevant to the given retrieval task.
Ensure the text chunks are:
- Focused and directly related to the retrieval task.
- Clear, truthful, and based on your general knowledge.
Do not include or reference the retrieval task itself in the generated chunks.
"""
# Backward-compatible alias for the original, misspelled name; code in this
# file (and possibly other modules) still refers to it.
SYSTEM_PROMPT_CHUCKS = SYSTEM_PROMPT_CHUNKS
# Jinja-style user template for chunk generation; `{{ task }}` is filled from
# the "task" column (see get_chunks_generator). Ends with a trailing newline.
CHUNKS_TEMPLATE = "\n".join(
    [
        "You have been assigned to generate text chunks based on the following retrieval task: {{ task }}.",
        "Provide only the text chunks without explaining your process or reasoning. Do not include any additional information. Do not indicate that it is a text chunk.",
        "Ensure the chunks are concise, clear, and directly relevant to the task.",
        "Use your general knowledge to create informative and precise outputs.",
        "",
    ]
)
# System prompt for answering a question from a document (see
# get_response_generator). It allows falling back to general knowledge, or
# saying clearly that no answer is possible.
SYSTEM_PROMPT_RAG = "\n".join(
    [
        "",
        "You are a helpful AI assistant. Your task is to answer the following question based on the provided document.",
        "If the answer is not explicitly stated in the document, use your knowledge to provide the most relevant and accurate answer possible.",
        "If you cannot answer the question based on the given information, state that clearly.",
        "",
    ]
)
# Jinja-style user template for RAG answering; `{{ context }}` and
# `{{ question }}` come from the "context" / "question" columns. There is no
# trailing newline, matching the original `.rstrip()`.
RAG_TEMPLATE = "\n".join(
    [
        "Document:",
        "{{ context }}",
        "Question: {{ question }}",
        "Please provide a clear and concise answer to the question based on the information in the document:",
    ]
)
def get_prompt_generator():
    """Create and load the task that expands a dataset description into a prompt.

    Returns:
        A loaded ``TextGeneration`` task that uses ``PROMPT_CREATION_PROMPT`` as
        its system prompt. It samples at temperature 0.8 with at most
        ``MAX_NUM_TOKENS`` new tokens.
    """
    llm = _get_llm(
        generation_kwargs={"temperature": 0.8, "max_new_tokens": MAX_NUM_TOKENS}
    )
    prompt_task = TextGeneration(
        llm=llm,
        system_prompt=PROMPT_CREATION_PROMPT,
        use_system_prompt=True,
    )
    prompt_task.load()
    return prompt_task
def get_chunks_generator(temperature: float, is_sample: bool):
    """Create and load the task that writes text chunks for a retrieval task.

    Args:
        temperature: Sampling temperature passed to the LLM.
        is_sample: When true, cap generation at 256 new tokens; otherwise
            allow up to ``MAX_NUM_TOKENS``.

    Returns:
        A loaded ``TextGeneration`` task. It fills ``CHUNKS_TEMPLATE`` from the
        ``task`` column and uses ``SYSTEM_PROMPT_CHUCKS`` as its system prompt.
    """
    generation_kwargs = {
        "temperature": temperature,
        # Use the same budget rule as get_sentence_pair_generator: sample runs
        # get the small cap, full runs the configured maximum. The inverted
        # form made samples longer than the full dataset's chunks.
        "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
    }
    text_generator = TextGeneration(
        llm=_get_llm(generation_kwargs=generation_kwargs),
        system_prompt=SYSTEM_PROMPT_CHUCKS,
        template=CHUNKS_TEMPLATE,
        columns=["task"],
        use_system_prompt=True,
    )
    text_generator.load()
    return text_generator
def get_sentence_pair_generator(
    action: str, triplet: bool, temperature: float, is_sample: bool
):
    """Create and load a ``GenerateSentencePair`` task with hard negatives.

    Args:
        action: Forwarded as the task's ``action``.
        triplet: Forwarded as the task's ``triplet`` flag.
        temperature: Sampling temperature passed to the LLM.
        is_sample: When true, cap generation at 256 new tokens; otherwise
            allow up to ``MAX_NUM_TOKENS``.

    Returns:
        The loaded ``GenerateSentencePair`` task.
    """
    token_budget = 256 if is_sample else MAX_NUM_TOKENS
    llm = _get_llm(
        generation_kwargs={"temperature": temperature, "max_new_tokens": token_budget}
    )
    pair_task = GenerateSentencePair(
        llm=llm,
        triplet=triplet,
        action=action,
        hard_negative=True,
    )
    pair_task.load()
    return pair_task
def get_response_generator(temperature: float, is_sample: bool):
    """Create and load the task that answers a question from a document.

    Args:
        temperature: Sampling temperature passed to the LLM.
        is_sample: When true, cap generation at 256 new tokens; otherwise
            allow up to ``MAX_NUM_TOKENS``.

    Returns:
        A loaded ``TextGeneration`` task. It fills ``RAG_TEMPLATE`` from the
        ``context`` and ``question`` columns and uses ``SYSTEM_PROMPT_RAG`` as
        its system prompt. The LLM is built with ``is_completion=True``.
    """
    generation_kwargs = {
        "temperature": temperature,
        # Use the same budget rule as get_sentence_pair_generator: sample runs
        # get the small cap, full runs the configured maximum. This keeps
        # full-dataset answers from being truncated at 256 tokens.
        "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
    }
    text_generator = TextGeneration(
        llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
        system_prompt=SYSTEM_PROMPT_RAG,
        template=RAG_TEMPLATE,
        columns=["context", "question"],
        use_system_prompt=True,
    )
    text_generator.load()
    return text_generator
def generate_pipeline_code(
    repo_id: str,
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: int = 10,
) -> str:
    """Render a standalone distilabel script that reproduces the RAG pipeline.

    Args:
        repo_id: Hub dataset id. It is only used for Hub loading. For
            ``"dataset-input"`` it may be ``None``, in which case the subset
            ``"default"`` and split ``"train"`` are assumed.
        input_type: ``"prompt-input"`` generates documents from
            ``system_prompt``. ``"file-input"`` loads pre-chunked data. Any
            other value loads ``repo_id`` from the Hub.
        system_prompt: Retrieval task description embedded as
            ``TASK_SYSTEM_PROMPT`` for ``"prompt-input"``.
        document_column: Hub dataset column mapped to ``anchor``.
        retrieval_reranking: Selected options. ``"Retrieval"`` adds negative
            retrieval passages; ``"Reranking"`` adds a reranking-pair step.
        num_rows: Number of rows to generate or load.

    Returns:
        Python source code of the generated script.
    """
    if input_type == "dataset-input" and repo_id is not None:
        # Use the first config and the first split of that config.
        subset = get_dataset_config_names(repo_id)[0]
        split = get_dataset_split_names(repo_id, subset)[0]
    else:
        subset = "default"
        split = "train"
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking
    is_prompt_input = input_type == "prompt-input"
    is_file_input = input_type == "file-input"
    # Every other input type loads from the Hub below, so the loader import
    # must follow exactly the same rule.
    uses_hub = not (is_prompt_input or is_file_input)

    # Resolve the LLM class and its serialized config once instead of once per
    # generated step. NOTE(review): assumes _get_llm() with no arguments yields
    # the same config on every call, as the original repeated calls implied.
    llm_class = _get_llm_class()
    llm_config = _get_llm().dump()

    step_imports = ["KeepColumns", "LoadDataFromHub" if uses_hub else "LoadDataFromDicts"]
    if reranking:
        # CombineOutputs is instantiated whenever reranking is enabled, so it
        # must be imported even when retrieval is off.
        step_imports.append("CombineOutputs")
    task_imports = ["GenerateSentencePair", "TextGeneration"]
    if is_prompt_input:
        task_imports.append("GenerateTextRetrievalData")
    random_import = "import random" if is_prompt_input else ""
    step_import_list = ", ".join(step_imports)
    task_import_list = ", ".join(task_imports)

    base_code = f"""
# Requirements: `pip install distilabel[hf-inference-endpoints]`
{random_import}
from distilabel.models import {llm_class}
from distilabel.pipeline import Pipeline
from distilabel.steps import {step_import_list}
from distilabel.steps.tasks import {task_import_list}

SYSTEM_PROMPT_RAG = '''
You are a helpful AI assistant. Your task is to answer the following question based on the provided document.
If the answer is not explicitly stated in the document, use your knowledge to provide the most relevant and accurate answer possible.
If you cannot answer the question based on the given information, state that clearly.
'''

RAG_TEMPLATE = '''Document:
{{{{ filename }}}}
Question: {{{{ question }}}}
Please provide a clear and concise answer to the question based on the information in the document:
'''.rstrip()
"""

    if is_file_input:
        base_code += """
# NOTE: `process_and_chunk_files` and `files` are not defined in this script.
# Replace the line below with your own loading/chunking code that builds a
# list of dicts whose "anchor" key holds one document chunk each.
data = process_and_chunk_files(files=[files])
"""

    if is_prompt_input:
        # Escape backslashes first, then triple quotes, so the prompt cannot
        # terminate the generated ''' literal or be reinterpreted as escapes.
        escaped_prompt = system_prompt.replace("\\", "\\\\").replace("'''", "\\'\\'\\'")
        pipeline = f"""
TASK_SYSTEM_PROMPT = '''
{escaped_prompt}
'''

with Pipeline(name="rag") as pipeline:

    task_generator = LoadDataFromDicts(data=[{{"task": TASK_SYSTEM_PROMPT}}])

    sentence_similarity_generation = GenerateTextRetrievalData(
        llm={llm_class}.from_dict(
            {llm_config}
        ),
        seed=random.randint(0, 2**32 - 1),
        query_type="common",
        difficulty="high school",
        clarity="clear",
        num_generations={num_rows},
        output_mappings={{"positive_document": "anchor"}},
    )

    keep_columns_prompt = KeepColumns(
        columns=["anchor"],
    )
"""
    else:
        pipeline = """
with Pipeline(name="rag") as pipeline:
"""
        if is_file_input:
            pipeline += """
    load_the_dataset = LoadDataFromDicts(
        data=data,
    )
"""
        else:
            # !r quotes the column name safely even if it contains quotes.
            pipeline += f"""
    load_the_dataset = LoadDataFromHub(
        repo_id="{repo_id}",
        config="{subset}",
        split="{split}",
        num_examples={num_rows},
        batch_size=2,
        output_mappings={{{document_column!r}: 'anchor'}},
    )
"""

    retrieval_mappings = '"positive": "positive_retrieval"'
    kept_columns = '"anchor", "positive_retrieval", "response"'
    if retrieval:
        retrieval_mappings += ', "negative": "negative_retrieval"'
        kept_columns += ', "negative_retrieval"'
    if reranking:
        kept_columns += ', "positive_reranking", "negative_reranking"'

    pipeline += f"""
    generate_retrieval_pairs = GenerateSentencePair(
        triplet={retrieval},
        hard_negative=True,
        action="query",
        llm={llm_class}.from_dict(
            {llm_config}
        ),
        output_mappings={{{retrieval_mappings}}},
        input_batch_size=10,
    )
"""
    if reranking:
        pipeline += f"""
    generate_reranking_pairs = GenerateSentencePair(
        triplet=True,
        hard_negative=True,
        action="semantically-similar",
        llm={llm_class}.from_dict(
            {llm_config}
        ),
        input_batch_size=10,
        output_mappings={{"positive": "positive_reranking", "negative": "negative_reranking"}},
    )

    combine_outputs = CombineOutputs()
"""
    pipeline += f"""
    generate_response = TextGeneration(
        llm={llm_class}.from_dict(
            {llm_config}
        ),
        system_prompt=SYSTEM_PROMPT_RAG,
        template=RAG_TEMPLATE,
        columns=["filename", "question"],
        use_system_prompt=True,
        input_mappings={{"filename": "anchor", "question": "positive_retrieval"}},
        output_mappings={{"generation": "response"}},
    )

    keep_columns = KeepColumns(
        columns=[{kept_columns}],
    )
"""

    pipeline_steps = (
        "[generate_retrieval_pairs, generate_reranking_pairs] >> combine_outputs >> generate_response >> keep_columns"
        if reranking
        else "generate_retrieval_pairs >> generate_response >> keep_columns"
    )
    if is_prompt_input:
        source_steps = "task_generator >> sentence_similarity_generation >> keep_columns_prompt"
    else:
        source_steps = "load_the_dataset"
    pipeline += f"""
    {source_steps} >> {pipeline_steps}
"""
    pipeline += """
if __name__ == "__main__":
    distiset = pipeline.run()
"""
    return base_code + pipeline