"""Generate instruction and preference fine-tuning datasets from cleaned documents."""
from abc import ABC, abstractmethod
import tiktoken
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models.fake import FakeListLLM
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.prompts import PromptTemplate
from langchain_ollama import ChatOllama
from loguru import logger
from llm_engineering import domain
from llm_engineering.application import utils
from llm_engineering.domain.cleaned_documents import CleanedDocument
from llm_engineering.domain.dataset import DatasetType, TrainTestSplit
from llm_engineering.domain.prompt import GenerateDatasetSamplesPrompt, Prompt
from llm_engineering.domain.types import DataCategory
from llm_engineering.settings import settings
from . import constants
from . import utils as generation_utils
from .output_parsers import ListPydanticOutputParser
class DatasetGenerator(ABC):
    """Abstract base for generating fine-tuning datasets from cleaned documents.

    Concrete subclasses set ``dataset_type`` and ``prompt_template_str`` and
    implement ``post_process_datasets`` to filter and split the generated samples.
    """

    # Used only to count tokens and truncate prompts to the OpenAI context window.
    tokenizer = tiktoken.encoding_for_model(settings.OPENAI_MODEL_ID)
    # Set by concrete subclasses: DatasetType.INSTRUCTION or DatasetType.PREFERENCE.
    dataset_type: DatasetType | None = None

    system_prompt_template = """You are a helpful assistant who generates {dataset_format} based on the given context. \
Provide your response in JSON format.
"""
    # Per-subclass user prompt template, rendered with a document extract.
    prompt_template_str: str | None = None

    @classmethod
    def get_system_prompt(cls) -> Prompt:
        """Return the system prompt, phrased for the subclass's dataset type."""
        assert cls.dataset_type is not None, "Dataset type must be set before calling get_system_prompt()"

        dataset_format = (
            "instruction-answer pairs" if cls.dataset_type == DatasetType.INSTRUCTION else "instruction-answer triples"
        )
        input_variables = {
            "dataset_format": dataset_format,
        }
        system_prompt = cls.system_prompt_template.format(**input_variables)

        return Prompt(
            template=cls.system_prompt_template,
            input_variables=input_variables,
            content=system_prompt,
        )

    @classmethod
    def get_prompts(cls, documents: list[CleanedDocument]) -> dict[DataCategory, list[GenerateDatasetSamplesPrompt]]:
        """Build one generation prompt per document extract, grouped by data category."""
        # Split long documents into smaller extracts before prompting.
        documents = generation_utils.extract_substrings(documents)

        grouped_prompts = {}
        grouped_cleaned_documents = CleanedDocument.group_by_category(documents)
        for category, category_documents in grouped_cleaned_documents.items():
            category_prompts = [cls.get_prompt(document) for document in category_documents]
            grouped_prompts[category] = category_prompts

        return grouped_prompts

    @classmethod
    def get_prompt(cls, document: CleanedDocument) -> GenerateDatasetSamplesPrompt:
        """Render the subclass prompt template for one document.

        The rendered prompt is hard-truncated at the token level to
        ``settings.OPENAI_MAX_TOKEN_WINDOW`` when it is too long.
        """
        assert cls.prompt_template_str is not None, "Prompt template must be set before calling get_prompt()"

        data_category = document.get_category()

        # NOTE(review): with template_format="jinja2", variables are written as
        # {{extract}}; the single-brace {extract} used in the subclass templates
        # would render literally under jinja2 — confirm the extract is actually
        # substituted into the prompt.
        prompt_template = PromptTemplate.from_template(
            template=cls.prompt_template_str,
            template_format="jinja2",
        )
        input_variables = {
            "extract": document.content,
        }
        prompt = prompt_template.format(**input_variables)
        prompt_tokens = cls.tokenizer.encode(prompt)
        if len(prompt_tokens) > settings.OPENAI_MAX_TOKEN_WINDOW:
            # Truncate on token boundaries, then decode back to text.
            prompt_tokens = prompt_tokens[: settings.OPENAI_MAX_TOKEN_WINDOW]
            prompt = cls.tokenizer.decode(prompt_tokens)

        prompt = GenerateDatasetSamplesPrompt(
            template=prompt_template.template,
            input_variables=input_variables,
            content=prompt,
            num_tokens=len(prompt_tokens),
            data_category=data_category,
            document=document,
        )

        return prompt

    @classmethod
    def generate(
        cls,
        prompts: dict[DataCategory, list[GenerateDatasetSamplesPrompt]],
        test_size: float = 0.2,
        mock: bool = False,
    ) -> TrainTestSplit:
        """Run the LLM over all prompts and build the train/test dataset split.

        Args:
            prompts: Generation prompts grouped by data category.
            test_size: Fraction of samples reserved for the test split.
            mock: When True, use a fake LLM with canned responses (no model calls).

        Returns:
            The post-processed train/test split produced by the subclass.
        """
        assert cls.dataset_type is not None, "Dataset type must be set before calling generate()"

        def _to_langchain(
            prompt: GenerateDatasetSamplesPrompt,
        ) -> list[BaseMessage]:
            # Pair the shared system prompt with the per-document user prompt.
            messages = [
                SystemMessage(content=cls.get_system_prompt().content),
                HumanMessage(content=prompt.content),
            ]

            return messages

        if mock:
            llm = FakeListLLM(responses=[constants.get_mocked_response(cls.dataset_type)])
        else:
            # NOTE(review): langchain_ollama's ChatOllama documents `num_predict`
            # as the output-token limit; confirm `max_tokens` is actually honored.
            llm = ChatOllama(
                model=settings.LLAMA_MODEL_ID,
                max_tokens=2000 if cls.dataset_type == DatasetType.PREFERENCE else 1200,
                temperature=0.7,
            )
        parser = ListPydanticOutputParser(pydantic_object=cls._get_dataset_sample_type())

        chain = llm | parser

        datasets = {}
        for category, category_prompts in prompts.items():
            langchain_category_prompts = [_to_langchain(prompt) for prompt in category_prompts]
            batches = utils.misc.batch(langchain_category_prompts, size=24)

            flattened_instruct_dataset_samples = []
            for batch in batches:
                try:
                    batched_dataset_samples = chain.batch(batch, stop=None)

                    for instruct_dataset_sample_batch in batched_dataset_samples:
                        flattened_instruct_dataset_samples.extend(instruct_dataset_sample_batch)
                except OutputParserException:
                    # Best-effort: drop batches whose JSON output cannot be parsed.
                    logger.exception(f"Failed to parse the output JSON for a batch for category {category}")

            dataset = domain.dataset.build_dataset(
                dataset_type=cls.dataset_type, category=category, samples=flattened_instruct_dataset_samples
            )
            datasets[category] = dataset
            logger.info(f"Generated {len(dataset.samples)} samples for category '{category}'.")

        processed_datasets = cls.post_process_datasets(datasets, test_size=test_size)

        return processed_datasets

    @classmethod
    def _get_dataset_sample_type(
        cls,
    ) -> type[domain.dataset.InstructDatasetSample] | type[domain.dataset.PreferenceDatasetSample]:
        """Map the dataset type to the pydantic sample model used for output parsing."""
        return (
            domain.dataset.InstructDatasetSample
            if cls.dataset_type == DatasetType.INSTRUCTION
            else domain.dataset.PreferenceDatasetSample
        )

    @classmethod
    @abstractmethod
    def post_process_datasets(
        cls, datasets: dict[DataCategory, domain.dataset.InstructDataset], test_size: float
    ) -> TrainTestSplit:
        """Filter the generated datasets and produce the final train/test split."""
        pass
class InstructionDatasetGenerator(DatasetGenerator):
    """Generator for instruction-answer pair (SFT) datasets."""

    dataset_type = DatasetType.INSTRUCTION

    prompt_template_str = """Based on the following extract, generate five instruction-answer pairs. Each instruction \
must ask to write about a specific topic contained in the context. Each answer \
must provide a relevant paragraph based on the information found in the \
context. Only use concepts from the context to generate the instructions. \
Instructions must never explicitly mention a context, a system, a course, or an extract. \
Instructions must be self-contained and general. \
Answers must imitate the writing style of the context. \
Example instruction: Explain the concept of an LLM Twin. \
Example answer: An LLM Twin is essentially an AI character that mimics your writing style, personality, and voice. \
It's designed to write just like you by incorporating these elements into a language model. \
The idea is to create a digital replica of your writing habits using advanced AI techniques. \
Structure the answer in JSON format, ready to be loaded in Python by json.loads(), as a list of objects.
Do not add any extra characters and provide your response in JSON format with the following structure:
[
{"instruction": "...", "answer": "..."},
...
]
Extract:
{extract}
"""

    @classmethod
    def post_process_datasets(
        cls, datasets: dict[DataCategory, domain.dataset.InstructDataset], test_size: float
    ) -> TrainTestSplit:
        """Split the generated instruct datasets into train/test partitions."""
        return generation_utils.create_instruct_train_test_split(datasets, test_size=test_size, random_state=42)
class PreferenceDatasetGenerator(DatasetGenerator):
    """Generator for preference (chosen/rejected) datasets used for DPO-style training."""

    dataset_type = DatasetType.PREFERENCE

    prompt_template_str = """Based on the following extract, generate five instruction-answer triples. Each triple should consist of:
1. An instruction asking about a specific topic in the context.
2. A generated answer that attempts to answer the instruction based on the context, named as 'rejected'.
3. An extracted answer that is a relevant excerpt directly from the given context, named as 'chosen'.
Instructions must be self-contained and general, without explicitly mentioning a context, system, course, or extract.
Important:
- Ensure that the extracted answer, the chosen one, is a verbatim copy from the context, including all punctuation and apostrophes.
- Do not add any ellipsis (...) or [...] to indicate skipped text in the extracted answer.
- If the relevant text is not continuous, use two separate sentences from the context instead of skipping text.
Structure the answer in JSON format, ready to be loaded in Python by json.loads(), as a list of objects.
Do not add any extra characters and provide your response in JSON format with the following structure:
[
{
"instruction": "...",
"rejected": "...",
"chosen": "..."
},
...
]
Extract:
{extract}
"""

    @classmethod
    def post_process_datasets(
        cls, datasets: dict[DataCategory, domain.dataset.PreferenceDataset], test_size: float
    ) -> TrainTestSplit:
        """Filter low-quality samples, then split into train/test partitions.

        Args:
            datasets: Generated preference datasets, keyed by data category.
            test_size: Fraction of samples reserved for the test split.

        Returns:
            The train/test split built from the filtered datasets.
        """
        datasets = generation_utils.filter_short_answers(datasets)
        datasets = generation_utils.filter_answer_format(datasets)

        # Generator expression instead of a throwaway list inside sum().
        remaining_samples = sum(dataset.num_samples for dataset in datasets.values())
        logger.info(
            f"Filtered out short answers and answers with incorrect format. Remaining samples: {remaining_samples}"
        )

        train_test_split = generation_utils.create_preference_train_test_split(
            datasets, test_size=test_size, random_state=42
        )

        return train_test_split
def get_dataset_generator(dataset_type: DatasetType) -> type[DatasetGenerator]:
    """Resolve the concrete generator class for the given dataset type.

    Raises:
        ValueError: If the dataset type has no registered generator.
    """
    generators: dict[DatasetType, type[DatasetGenerator]] = {
        DatasetType.INSTRUCTION: InstructionDatasetGenerator,
        DatasetType.PREFERENCE: PreferenceDatasetGenerator,
    }

    generator = generators.get(dataset_type)
    if generator is None:
        raise ValueError(f"Invalid dataset type: {dataset_type}")

    return generator