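# app.py: Apex Customs car-configuration assistant. A Gradio chat app that
# retrieves context from a product PDF with FAISS + sentence-transformers and
# streams answers from Zephyr-7B via the Hugging Face Inference API.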
import os
import json
import gradio as gr
import faiss
import fitz  # PyMuPDF
import numpy as np
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer

# Initialize the SentenceTransformer model
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Extract text from the PDF and split it into paragraph-sized chunks
def extract_text_from_pdf(pdf_path):
    doc = fitz.open(pdf_path)
    text = ""
    for page in doc:
        text += page.get_text()
    # Split on blank lines and drop empty chunks so they are not embedded
    return [chunk.strip() for chunk in text.split("\n\n") if chunk.strip()]
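
# Hypothetical sanity check (chunk boundaries depend on the PDF's layout):
#   chunks = extract_text_from_pdf("apexcustoms.pdf")
#   print(len(chunks), chunks[0][:80])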

# Build a FAISS index over the document chunks and persist it to disk
def build_faiss_index(documents):
    # FAISS expects a float32 matrix of shape (n_documents, embedding_dim)
    document_embeddings = np.asarray(model.encode(documents), dtype="float32")

    index = faiss.IndexFlatL2(document_embeddings.shape[1])
    index.add(document_embeddings)

    faiss.write_index(index, "apexcustoms_index.faiss")
    model.save("sentence_transformer_model")

    return index
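
# Rebuilding from a different source PDF (hypothetical file name):
#   docs = extract_text_from_pdf("another_catalog.pdf")
#   index = build_faiss_index(docs)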

# Build the index on the first run; afterwards load the cached index and chunks
if not os.path.exists("apexcustoms_index.faiss") or not os.path.exists("sentence_transformer_model"):
    documents = extract_text_from_pdf("apexcustoms.pdf")
    with open("apexcustoms.json", "w") as f:
        json.dump(documents, f)
    index = build_faiss_index(documents)
else:
    index = faiss.read_index("apexcustoms_index.faiss")
    # Reload the chunks saved on the first run; retrieve_documents needs them
    with open("apexcustoms.json") as f:
        documents = json.load(f)

# Hugging Face Inference API client for the chat model
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# Return the k chunks whose embeddings are closest (L2 distance) to the query
def retrieve_documents(query, k=5):
    query_embedding = np.asarray(model.encode([query]), dtype="float32")
    distances, indices = index.search(query_embedding, k)
    return [documents[i] for i in indices[0]]
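
# Hypothetical query:
#   retrieve_documents("matte black wrap options", k=3)
#   -> the 3 chunks nearest to the query in embedding space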

def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Retrieve relevant documents and keep only the top 3 as context
    relevant_docs = retrieve_documents(message)
    context = "\n\n".join(relevant_docs[:3])

    # Limit history to the last 5 exchanges to reduce payload size
    history = history[-5:]

    # Order: system prompt, prior exchanges, then the current query with
    # the retrieved context prepended
    messages = [{"role": "system", "content": system_message}]

    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": f"Context: {context}\n\n{message}"})

    # Stream tokens, yielding the accumulated response so the Gradio UI
    # shows the message growing instead of one isolated token at a time
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        if chunk.choices and chunk.choices[0].delta.content:
            response += chunk.choices[0].delta.content
            yield response

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful car configuration assistant, specifically you are the assistant for Apex Customs (https://www.apexcustoms.com/). Given the user's input, provide suggestions for car models, colors, and customization options. Be creative and conversational in your responses. You should remember the user's car model and tailor your answers accordingly.",
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=2048, step=1, value=512, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, step=0.1, value=0.7, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.95, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()