import gradio as gr
from huggingface_hub import InferenceClient
from pathlib import Path
from transformers import RagTokenForGeneration, RagTokenizer
import faiss
from typing import List
from pdfplumber import open as open_pdf
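# NOTE: assumed dependencies for this script (exact install names may vary by
# environment; torch is needed because the encoders return PyTorch tensors):
#   pip install gradio huggingface_hub transformers faiss-cpu pdfplumber torch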

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# Load the PDF file and extract its text
pdf_path = Path("apexcustoms.pdf")
with open_pdf(pdf_path) as pdf:
    # extract_text() returns None for pages with no text layer, so guard it
    text = "\n".join(page.extract_text() or "" for page in pdf.pages)

# Split the PDF text into chunks
chunk_size = 1000  # Adjust this value based on your needs
text_chunks: List[str] = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
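# A fixed character window can cut sentences in half at chunk boundaries. A
# common refinement (sketched here as an assumption, not used above) is to
# overlap consecutive chunks so each boundary sentence appears in two chunks:
#   overlap = 200
#   text_chunks = [text[i:i + chunk_size]
#                  for i in range(0, len(text), chunk_size - overlap)]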

# Load the RAG model and tokenizer for retrieval
rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
rag_model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
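# Only the DPR question encoder inside the RAG checkpoint is used below, to
# embed both the PDF chunks and incoming queries; the RAG generator head is
# never called, since generation is delegated to the zephyr endpoint instead.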

# Create an in-memory FAISS index over the chunk embeddings. The question
# encoder returns a DPRQuestionEncoderOutput, so take .pooler_output to get
# the embedding tensor; pass the attention mask so padding tokens are ignored.
chunk_inputs = rag_tokenizer(text_chunks, padding=True, truncation=True, return_tensors="pt")
embeddings = rag_model.question_encoder(
    chunk_inputs["input_ids"], attention_mask=chunk_inputs["attention_mask"]
).pooler_output
index = faiss.IndexFlatL2(embeddings.size(-1))
index.add(embeddings.detach().numpy())
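# IndexFlatL2 is exact (brute-force) L2 search, which is fine at this scale.
# DPR embeddings were trained with an inner-product objective, so
# faiss.IndexFlatIP would arguably match the model better; L2 is kept as-is.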

# Custom retriever: brute-force nearest-neighbour lookup over the FAISS index
class CustomRetriever:
    def __init__(self, documents, embeddings, index):
        self.documents = documents
        self.embeddings = embeddings
        self.index = index

    def get_relevant_docs(self, query_embeddings, top_k=4):
        # index.search returns (distances, indices) for the top_k nearest chunks
        scores, doc_indices = self.index.search(query_embeddings.detach().numpy(), top_k)
        return [(self.documents[doc_idx], score) for doc_idx, score in zip(doc_indices[0], scores[0])]

# Create a custom retriever instance
retriever = CustomRetriever(text_chunks, embeddings, index)
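# Illustrative only (hypothetical query string, not part of the app flow):
#   q = rag_model.question_encoder(
#       rag_tokenizer("What body kits are available?", return_tensors="pt").input_ids
#   ).pooler_output
#   retriever.get_relevant_docs(q, top_k=2)  # -> [(chunk_text, l2_distance), ...]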

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
):
    messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    # Retrieve relevant chunks using the custom retriever
    rag_input_ids = rag_tokenizer(message, return_tensors="pt").input_ids
    query_embeddings = rag_model.question_encoder(rag_input_ids)
    relevant_docs = retriever.get_relevant_docs(query_embeddings)
    retrieved_text = "\n".join([doc for doc, _ in relevant_docs])

    # Generate the response using the zephyr model
    for message in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        files={"context": retrieved_text},  # Pass retrieved text as context
    ):
        token = message.choices[0].delta.content
        response += token
        yield response

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful car configuration assistant, specifically you are the assistant for Apex Customs (https://www.apexcustoms.com/). Given the user's input, provide suggestions for car models, colors, and customization options. Be conversational in your responses. You should remember the user car model and tailor your answers accordingly. You limit yourself to answering the given question and maybe propose a suggestion but not write the next question of the user. \n\nUser: ", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    ],
)

if __name__ == "__main__":
    demo.launch()