from typing import List

from pypdf import PdfReader

from haystack import Pipeline, Document, component
from haystack.components.builders import DynamicChatPromptBuilder
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.generators.chat import OpenAIChatGenerator, HuggingFaceTGIChatGenerator
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.components.writers import DocumentWriter
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import Secret

# Embedding model shared by the document embedder (indexing) and the text embedder (querying).
SENTENCE_RETRIEVER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

MAX_TOKENS = 500

# Reference RAG prompt template; not wired into the pipelines below
# (DocumentQAEngine.inference builds its chat prompt inline).
template = """
You are a professional HR recruiter. Given the following information, answer the question concisely in 1 or 2 sentences.

Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}

Question: {{question}}
Answer:
"""


@component
class UploadedFileConverter:
    """
    A component that converts an uploaded PDF file into one Document per page.
    """

    @component.output_types(documents=List[Document])
    def run(self, uploaded_file):
        pdf = PdfReader(uploaded_file)
        documents = []

        # Derive a base name for the page documents from the file name, dropping
        # a trailing ".pdf"/".PDF" extension if present. (str.rstrip would strip
        # any of the characters '.', 'p', 'd', 'f' rather than the suffix.)
        name = uploaded_file.name
        if name.lower().endswith(".pdf"):
            name = name[: -len(".pdf")]
        for page in pdf.pages:
            documents.append(
                Document(
                    content=page.extract_text(),
                    meta={"name": f"{name}_{page.page_number}"},
                )
            )
        return {"documents": documents}


def create_ingestion_pipeline(document_store):
    """Build the indexing pipeline: convert, clean, split, embed, and write PDF pages."""
    doc_embedder = SentenceTransformersDocumentEmbedder(model=SENTENCE_RETRIEVER_MODEL)
    doc_embedder.warm_up()

    pipeline = Pipeline()
    pipeline.add_component("converter", UploadedFileConverter())
    pipeline.add_component("cleaner", DocumentCleaner())
    pipeline.add_component(
        "splitter",
        DocumentSplitter(split_by="passage", split_length=100, split_overlap=10),
    )
    pipeline.add_component("embedder", doc_embedder)
    pipeline.add_component(
        "writer",
        DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE),
    )

    pipeline.connect("converter", "cleaner")
    pipeline.connect("cleaner", "splitter")
    pipeline.connect("splitter", "embedder")
    pipeline.connect("embedder", "writer")
    return pipeline
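
# A minimal standalone run of the ingestion pipeline (illustrative sketch; the
# "cv.pdf" file name is an assumption — any binary file object that exposes a
# .name attribute and that pypdf can read will do):
#
#     store = InMemoryDocumentStore()
#     ingest = create_ingestion_pipeline(store)
#     with open("cv.pdf", "rb") as f:
#         ingest.run({"converter": {"uploaded_file": f}})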


def create_inference_pipeline(document_store, model_name, api_key):
    """Build the RAG pipeline: embed the query, retrieve, build the prompt, generate."""
    if model_name == "local LLM":
        # Talk to a local OpenAI-compatible server; the placeholder token only
        # satisfies OpenAIChatGenerator's required api_key parameter.
        generator = OpenAIChatGenerator(
            api_key=Secret.from_token("<local LLM doesn't need an API key>"),
            model=model_name,
            api_base_url="http://localhost:1234/v1",
            generation_kwargs={"max_tokens": MAX_TOKENS},
        )
    elif "gpt" in model_name:
        generator = OpenAIChatGenerator(
            api_key=Secret.from_token(api_key),
            model=model_name,
            generation_kwargs={"max_tokens": MAX_TOKENS, "temperature": 0},
            streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
        )
    else:
        generator = HuggingFaceTGIChatGenerator(
            token=Secret.from_token(api_key),
            model=model_name,
            generation_kwargs={"max_new_tokens": MAX_TOKENS},
        )
    pipeline = Pipeline()
    pipeline.add_component(
        "text_embedder", SentenceTransformersTextEmbedder(model=SENTENCE_RETRIEVER_MODEL)
    )
    pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store, top_k=3))
    pipeline.add_component(
        "prompt_builder", DynamicChatPromptBuilder(runtime_variables=["query", "documents"])
    )
    pipeline.add_component("llm", generator)
    pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
    pipeline.connect("retriever.documents", "prompt_builder.documents")
    pipeline.connect("prompt_builder.prompt", "llm.messages")

    return pipeline


class DocumentQAEngine:
    def __init__(self, model_name, api_key=None):
        self.api_key = api_key
        self.model_name = model_name
        document_store = InMemoryDocumentStore()
        self.chunks = []
        self.inference_pipeline = create_inference_pipeline(document_store, model_name, api_key)
        self.pdf_ingestion_pipeline = create_ingestion_pipeline(document_store)

    def ingest_pdf(self, uploaded_file):
        self.pdf_ingestion_pipeline.run({"converter": {"uploaded_file": uploaded_file}})

    def inference(self, query, input_messages: List[dict]):
        system_message = ChatMessage.from_system(
            "You are a professional analyzer of git repos with access to the repo content. "
            "Answer in 1-3 sentences."
        )
        messages = [system_message]
        # Replay the chat history, mapping each role to the matching ChatMessage factory.
        for message in input_messages:
            if message["role"] == "user":
                messages.append(ChatMessage.from_user(message["content"]))
            else:
                messages.append(ChatMessage.from_assistant(message["content"]))
        # The final user message is a Jinja template that the prompt builder fills in
        # with the retrieved documents and the current query.
        messages.append(ChatMessage.from_user("""
        Relevant information from the uploaded repo:
        {% for doc in documents %}
        {{ doc.content }}
        {% endfor %}

        Question: {{query}}
        Answer:
        """))
        res = self.inference_pipeline.run(
            data={
                "text_embedder": {"text": query},
                "prompt_builder": {"prompt_source": messages, "query": query},
            }
        )
        return res["llm"]["replies"][0].content
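
# Example end-to-end usage (illustrative sketch; the file name, question, and
# model/API-key values are assumptions, not part of this module):
#
#     engine = DocumentQAEngine(model_name="gpt-4", api_key="<OPENAI_API_KEY>")
#     with open("cv.pdf", "rb") as f:
#         engine.ingest_pdf(f)
#     answer = engine.inference("What is the candidate's most recent role?", input_messages=[])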