# Spaces:
# Build error
# Build error
from abc import ABC, abstractmethod | |
import tiktoken | |
from langchain_core.exceptions import OutputParserException | |
from langchain_core.language_models.fake import FakeListLLM | |
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage | |
from langchain_core.prompts import PromptTemplate | |
from langchain_ollama import ChatOllama | |
from loguru import logger | |
from llm_engineering import domain | |
from llm_engineering.application import utils | |
from llm_engineering.domain.cleaned_documents import CleanedDocument | |
from llm_engineering.domain.dataset import DatasetType, TrainTestSplit | |
from llm_engineering.domain.prompt import GenerateDatasetSamplesPrompt, Prompt | |
from llm_engineering.domain.types import DataCategory | |
from llm_engineering.settings import settings | |
from . import constants | |
from . import utils as generation_utils | |
from .output_parsers import ListPydanticOutputParser | |
class DatasetGenerator(ABC):
    """Base class for generators that turn cleaned documents into LLM-produced datasets.

    The class is used without instantiation: every method is a classmethod that
    reads its configuration from class attributes. Concrete subclasses must set
    ``dataset_type`` and ``prompt_template_str`` and implement
    ``post_process_datasets``.
    """

    # Tokenizer used by ``get_prompt`` to count and truncate prompt tokens.
    tokenizer = tiktoken.encoding_for_model(settings.OPENAI_MODEL_ID)
    # Kind of dataset to generate; ``None`` until a subclass overrides it.
    dataset_type: DatasetType | None = None

    # ``{dataset_format}`` is filled in by ``get_system_prompt`` via ``str.format``.
    system_prompt_template = """You are a helpful assistant who generates {dataset_format} based on the given context. \
Provide your response in JSON format.
"""
    # Per-document prompt template; ``None`` until a subclass overrides it.
    prompt_template_str: str | None = None

    @classmethod
    def get_system_prompt(cls) -> Prompt:
        """Build the system prompt for the configured dataset type.

        Returns:
            A ``Prompt`` holding the raw template, the variables used to fill
            it and the rendered content.

        Raises:
            AssertionError: If ``dataset_type`` has not been set.
        """
        assert cls.dataset_type is not None, "Dataset type must be set before calling get_system_prompt()"

        dataset_format = (
            "instruction-answer pairs" if cls.dataset_type == DatasetType.INSTRUCTION else "instruction-answer triples"
        )
        input_variables = {
            "dataset_format": dataset_format,
        }
        system_prompt = cls.system_prompt_template.format(**input_variables)

        return Prompt(
            template=cls.system_prompt_template,
            input_variables=input_variables,
            content=system_prompt,
        )

    @classmethod
    def get_prompts(cls, documents: list[CleanedDocument]) -> dict[DataCategory, list[GenerateDatasetSamplesPrompt]]:
        """Create one generation prompt per document, grouped by data category.

        Args:
            documents: Cleaned documents to build prompts from. They are first
                passed through ``generation_utils.extract_substrings``
                (presumably splitting them into smaller extracts - confirm
                against that helper).

        Returns:
            A mapping from category to the prompts built for that category.
        """
        documents = generation_utils.extract_substrings(documents)

        grouped_cleaned_documents = CleanedDocument.group_by_category(documents)
        return {
            category: [cls.get_prompt(document) for document in category_documents]
            for category, category_documents in grouped_cleaned_documents.items()
        }

    @classmethod
    def get_prompt(cls, document: CleanedDocument) -> GenerateDatasetSamplesPrompt:
        """Render the generation prompt for a single document.

        The rendered prompt is truncated to ``settings.OPENAI_MAX_TOKEN_WINDOW``
        tokens when it exceeds that limit.

        Args:
            document: The cleaned document whose content is exposed to the
                template as the ``extract`` variable.

        Returns:
            The prompt together with its token count, category and source document.

        Raises:
            AssertionError: If ``prompt_template_str`` has not been set.
        """
        assert cls.prompt_template_str is not None, "Prompt template must be set before calling get_prompt()"

        data_category = document.get_category()

        # NOTE: with the jinja2 format only ``{{ name }}`` placeholders are
        # substituted; single-brace text in the template is emitted verbatim.
        prompt_template = PromptTemplate.from_template(
            template=cls.prompt_template_str,
            template_format="jinja2",
        )
        input_variables = {
            "extract": document.content,
        }
        rendered_prompt = prompt_template.format(**input_variables)

        prompt_tokens = cls.tokenizer.encode(rendered_prompt)
        if len(prompt_tokens) > settings.OPENAI_MAX_TOKEN_WINDOW:
            # Cut on token boundaries, then decode back so content and
            # ``num_tokens`` describe the same (truncated) text.
            prompt_tokens = prompt_tokens[: settings.OPENAI_MAX_TOKEN_WINDOW]
            rendered_prompt = cls.tokenizer.decode(prompt_tokens)

        return GenerateDatasetSamplesPrompt(
            template=prompt_template.template,
            input_variables=input_variables,
            content=rendered_prompt,
            num_tokens=len(prompt_tokens),
            data_category=data_category,
            document=document,
        )

    @classmethod
    def generate(
        cls,
        prompts: dict[DataCategory, list[GenerateDatasetSamplesPrompt]],
        test_size: float = 0.2,
        mock: bool = False,
    ) -> TrainTestSplit:
        """Run the prompts through the LLM and build a train/test split.

        Args:
            prompts: Generation prompts grouped by category, as returned by
                ``get_prompts``.
            test_size: Forwarded to ``post_process_datasets``.
            mock: When true, a ``FakeListLLM`` returning a canned response is
                used instead of the Ollama chat model.

        Returns:
            Whatever ``post_process_datasets`` produces for the generated datasets.

        Raises:
            AssertionError: If ``dataset_type`` has not been set.
        """
        assert cls.dataset_type is not None, "Dataset type must be set before calling generate()"

        # The system prompt depends only on class state, so build it once
        # instead of once per prompt.
        system_prompt_content = cls.get_system_prompt().content

        def _to_langchain(
            prompt: GenerateDatasetSamplesPrompt,
        ) -> list[BaseMessage]:
            """Convert a domain prompt into a system + human message pair."""
            return [
                SystemMessage(content=system_prompt_content),
                HumanMessage(content=prompt.content),
            ]

        if mock:
            llm = FakeListLLM(responses=[constants.get_mocked_response(cls.dataset_type)])
        else:
            # NOTE(review): ChatOllama documents ``num_predict`` as its output
            # token limit; confirm that ``max_tokens`` is actually honored here.
            llm = ChatOllama(
                model=settings.LLAMA_MODEL_ID,
                max_tokens=2000 if cls.dataset_type == DatasetType.PREFERENCE else 1200,
                temperature=0.7,
            )
        parser = ListPydanticOutputParser(pydantic_object=cls._get_dataset_sample_type())

        chain = llm | parser

        datasets = {}
        for category, category_prompts in prompts.items():
            langchain_category_prompts = [_to_langchain(prompt) for prompt in category_prompts]
            batches = utils.misc.batch(langchain_category_prompts, size=24)

            flattened_instruct_dataset_samples = []
            for batch in batches:
                try:
                    batched_dataset_samples = chain.batch(batch, stop=None)

                    for instruct_dataset_sample_batch in batched_dataset_samples:
                        flattened_instruct_dataset_samples.extend(instruct_dataset_sample_batch)
                except OutputParserException:
                    # Best effort: a batch whose output cannot be parsed is
                    # logged and dropped; the remaining batches still run.
                    logger.exception(f"Failed to parse the output JSON for a batch for category {category}")

            dataset = domain.dataset.build_dataset(
                dataset_type=cls.dataset_type, category=category, samples=flattened_instruct_dataset_samples
            )
            datasets[category] = dataset
            logger.info(f"Generated {len(dataset.samples)} samples for category '{category}'.")

        processed_datasets = cls.post_process_datasets(datasets, test_size=test_size)

        return processed_datasets

    @classmethod
    def _get_dataset_sample_type(
        cls,
    ) -> type[domain.dataset.InstructDatasetSample] | type[domain.dataset.PreferenceDatasetSample]:
        """Return the sample model the output parser should produce.

        ``InstructDatasetSample`` for instruction datasets, otherwise
        ``PreferenceDatasetSample``.
        """
        return (
            domain.dataset.InstructDatasetSample
            if cls.dataset_type == DatasetType.INSTRUCTION
            else domain.dataset.PreferenceDatasetSample
        )

    @classmethod
    @abstractmethod
    def post_process_datasets(
        cls, datasets: dict[DataCategory, domain.dataset.InstructDataset], test_size: float
    ) -> TrainTestSplit:
        """Turn the generated per-category datasets into a train/test split.

        Must be implemented by subclasses; ``generate`` calls it as its last step.
        """
        pass
class InstructionDatasetGenerator(DatasetGenerator):
    """Generator for instruction-answer pair datasets."""

    dataset_type = DatasetType.INSTRUCTION

    # Rendered with the jinja2 template format, so the document content must be
    # referenced as ``{{ extract }}``; a single-brace ``{extract}`` would be
    # emitted literally and the extract would never reach the model. The
    # single-brace JSON example below is intentionally literal text.
    prompt_template_str = """Based on the following extract, generate five instruction-answer pairs. Each instruction \
must ask to write about a specific topic contained in the context. Each answer \
must provide a relevant paragraph based on the information found in the \
context. Only use concepts from the context to generate the instructions. \
Instructions must never explicitly mention a context, a system, a course, or an extract. \
Instructions must be self-contained and general. \
Answers must imitate the writing style of the context. \
Example instruction: Explain the concept of an LLM Twin. \
Example answer: An LLM Twin is essentially an AI character that mimics your writing style, personality, and voice. \
It's designed to write just like you by incorporating these elements into a language model. \
The idea is to create a digital replica of your writing habits using advanced AI techniques. \
Structure the answer in JSON format, ready to be loaded in Python by json.loads(), as a list of objects.
Do not add any extra characters and provide your response in JSON format with the following structure:
[
{"instruction": "...", "answer": "..."},
...
]
Extract:
{{ extract }}
"""

    @classmethod
    def post_process_datasets(
        cls, datasets: dict[DataCategory, domain.dataset.InstructDataset], test_size: float
    ) -> TrainTestSplit:
        """Split the generated instruction datasets into train and test sets.

        Args:
            datasets: Generated instruction datasets keyed by category.
            test_size: Forwarded to the split helper as ``test_size``.

        Returns:
            The result of ``generation_utils.create_instruct_train_test_split``,
            called with a fixed ``random_state`` of 42.
        """
        return generation_utils.create_instruct_train_test_split(
            datasets, test_size=test_size, random_state=42
        )
class PreferenceDatasetGenerator(DatasetGenerator):
    """Generator for preference (instruction / rejected / chosen) datasets."""

    dataset_type = DatasetType.PREFERENCE

    # Rendered with the jinja2 template format, so the document content must be
    # referenced as ``{{ extract }}``; a single-brace ``{extract}`` would be
    # emitted literally and the extract would never reach the model. The
    # single-brace JSON example below is intentionally literal text.
    prompt_template_str = """Based on the following extract, generate five instruction-answer triples. Each triple should consist of:
1. An instruction asking about a specific topic in the context.
2. A generated answer that attempts to answer the instruction based on the context, named as 'rejected'.
3. An extracted answer that is a relevant excerpt directly from the given context, named as 'chosen'.
Instructions must be self-contained and general, without explicitly mentioning a context, system, course, or extract.
Important:
- Ensure that the extracted answer, the chosen one, is a verbatim copy from the context, including all punctuation and apostrophes.
- Do not add any ellipsis (...) or [...] to indicate skipped text in the extracted answer.
- If the relevant text is not continuous, use two separate sentences from the context instead of skipping text.
Structure the answer in JSON format, ready to be loaded in Python by json.loads(), as a list of objects.
Do not add any extra characters and provide your response in JSON format with the following structure:
[
{
"instruction": "...",
"rejected": "...",
"chosen": "..."
},
...
]
Extract:
{{ extract }}
"""

    @classmethod
    def post_process_datasets(
        cls, datasets: dict[DataCategory, domain.dataset.PreferenceDataset], test_size: float
    ) -> TrainTestSplit:
        """Filter the generated preference datasets and split them into train and test sets.

        Args:
            datasets: Generated preference datasets keyed by category.
            test_size: Forwarded to the split helper as ``test_size``.

        Returns:
            The result of ``generation_utils.create_preference_train_test_split``,
            called with a fixed ``random_state`` of 42.
        """
        # Two filtering passes; per the log message below they drop short
        # answers and answers with an incorrect format.
        datasets = generation_utils.filter_short_answers(datasets)
        datasets = generation_utils.filter_answer_format(datasets)

        remaining_samples = sum(dataset.num_samples for dataset in datasets.values())
        logger.info(
            f"Filtered out short answers and answers with incorrect format. Remaining samples: {remaining_samples}"
        )

        return generation_utils.create_preference_train_test_split(
            datasets, test_size=test_size, random_state=42
        )
def get_dataset_generator(dataset_type: DatasetType) -> type[DatasetGenerator]:
    """Look up the generator class that handles ``dataset_type``.

    Args:
        dataset_type: The kind of dataset to generate.

    Returns:
        ``InstructionDatasetGenerator`` for ``DatasetType.INSTRUCTION`` and
        ``PreferenceDatasetGenerator`` for ``DatasetType.PREFERENCE``.

    Raises:
        ValueError: If ``dataset_type`` matches neither known type.
    """
    # Checked in order with ``==``, exactly like an if/elif chain would.
    known_generators = (
        (DatasetType.INSTRUCTION, InstructionDatasetGenerator),
        (DatasetType.PREFERENCE, PreferenceDatasetGenerator),
    )
    for known_type, generator_cls in known_generators:
        if dataset_type == known_type:
            return generator_cls

    raise ValueError(f"Invalid dataset type: {dataset_type}")