File size: 3,263 Bytes
588f69f
 
4cce6fa
edc2346
 
 
4cce6fa
edc2346
4cce6fa
 
 
edc2346
 
 
 
 
 
 
 
 
 
daeb152
 
588f69f
 
69e3377
588f69f
 
 
 
 
 
edc2346
588f69f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daeb152
4cce6fa
 
daeb152
 
 
 
 
 
 
 
 
 
 
4cce6fa
edc2346
588f69f
 
 
 
 
 
 
 
4cce6fa
588f69f
 
 
 
edc2346
588f69f
 
 
 
 
 
 
69e3377
588f69f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import RagTokenizer, RagTokenForGeneration
from typing import List, Dict, Tuple
import re
import os
import torch

# Load the RAG model and tokenizer (fetched from the Hugging Face Hub on first
# run; both are large, so startup is slow and memory-heavy).
rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
rag_model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")

# Load the knowledge-base document.
# NOTE(review): a PDF is a binary container whose page text usually lives in
# compressed content streams, so decoding the raw bytes as UTF-8 (ignoring
# errors) does not yield the readable page text. A real PDF text extractor is
# needed for retrieval over this file to be meaningful -- TODO confirm the
# file format and add proper extraction.
pdf_path = "apexcustoms.pdf"
with open(pdf_path, 'rb') as f:
    pdf_text = f.read().decode('utf-8', errors='ignore')

# Split the text into paragraph-sized chunks. Separator lines may contain
# spaces/tabs or Windows line endings ("\r\n\r\n"), so match
# "newline, optional whitespace, newline", then strip every chunk and drop the
# ones that are empty so they don't turn into padding-only rows below.
split_pattern = r'\n\s*\n'
doc_chunks = [
    chunk.strip() for chunk in re.split(split_pattern, pdf_text) if chunk.strip()
]

# Tokenize every chunk into one padded batch of token ids. With no usable
# text there is nothing to tokenize, so the corpus is left as None.
corpus = (
    rag_tokenizer(doc_chunks, return_tensors="pt", padding=True, truncation=True).input_ids
    if doc_chunks
    else None
)

# Hosted chat model reached through the Hugging Face Inference API. See
# https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


def _retrieve_context(
    query: str,
    chunks: List[str],
    top_k: int = 3,
    max_chars: int = 1500,
) -> str:
    """Return the chunks sharing the most words with *query*, joined by blank lines.

    Simple keyword-overlap ranking: each chunk is scored by how many distinct
    query words (3+ characters, case-insensitive) it contains. Chunks with no
    overlapping word are ignored, so an unrelated query yields "". At most
    *top_k* chunks are returned, each cut to *max_chars* characters so the
    prompt stays bounded. Ties keep document order.
    """
    # Very short words ("a", "is", "to") would match nearly every chunk.
    query_terms = {word for word in re.findall(r"\w+", query.lower()) if len(word) >= 3}
    if not query_terms or top_k <= 0:
        return ""

    ranked = []
    for position, chunk in enumerate(chunks):
        overlap = len(query_terms & set(re.findall(r"\w+", chunk.lower())))
        if overlap:
            ranked.append((overlap, position, chunk))

    # Highest overlap first; the position breaks ties in document order.
    ranked.sort(key=lambda item: (-item[0], item[1]))
    return "\n\n".join(chunk.strip()[:max_chars] for _, _, chunk in ranked[:top_k])


def respond(
    message,
    history: List[Tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat reply from the hosted model, grounded in the loaded document.

    The chunks of ``doc_chunks`` that share the most words with ``message`` are
    prepended to the system prompt as reference material; the prior
    conversation and the new user message follow, and the reply is streamed
    from ``client``.

    Args:
        message: The new user message.
        history: Prior turns as ``(user, assistant)`` pairs; empty sides are skipped.
        system_message: System prompt text supplied by the UI.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated reply text after each non-empty streamed token.
    """
    # The context must be part of the prompt itself: chat_completion has no
    # separate "context" parameter.
    context = _retrieve_context(message, doc_chunks)
    system_content = system_message
    if context:
        system_content = (
            "Reference excerpts from the Apex Customs document "
            "(use them when relevant):\n"
            f"{context}\n\n"
            f"{system_message}"
        )

    messages = [{"role": "system", "content": system_content}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    response = ""
    # Named `update` so it does not shadow the `message` parameter.
    for update in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = update.choices[0].delta.content
        # Stream chunks may carry no text (content is None), e.g. the final one.
        if token:
            response += token
            yield response

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI. Gradio passes the additional inputs to `respond` positionally after
# `message` and `history`: system prompt, max new tokens, temperature, top-p.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            label="System message",
            value=(
                "You are a helpful car configuration assistant, specifically you are the "
                "assistant for Apex Customs (https://www.apexcustoms.com/). Given the "
                "user's input, provide suggestions for car models, colors, and "
                "customization options. Be creative and conversational in your "
                "responses. You should remember the user car model and tailor your "
                "answers accordingly. \n\nUser: "
            ),
        ),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)


# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()