File size: 5,234 Bytes
7279c69
055baa9
aeb1340
edb320d
 
 
aeb1340
 
edb320d
 
 
 
 
aeb1340
 
 
 
9994f95
 
 
f16ca94
 
 
 
 
 
 
 
055baa9
 
 
 
 
 
 
 
 
 
 
 
aeb1340
 
 
 
 
 
 
edb320d
aeb1340
 
 
edb320d
 
 
 
 
aeb1340
edb320d
aeb1340
 
9994f95
 
aeb1340
9994f95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515f14b
055baa9
c0a84b3
 
 
055baa9
515f14b
055baa9
 
 
 
 
 
1f5b72f
c0a84b3
055baa9
c0a84b3
055baa9
 
 
 
 
 
515f14b
055baa9
 
515f14b
 
 
 
 
 
 
 
 
 
 
 
aeb1340
098591c
aeb1340
 
 
cdb865c
aeb1340
 
9994f95
aeb1340
 
515f14b
 
 
f16ca94
515f14b
 
aeb1340
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import random
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_groq import ChatGroq
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough


# Module-wide FAISS index; stays None until a PDF has been indexed.
vector_store = None

# Bundled sample document offered on the Indexing tab.
sample_filename = "Attention Is All You Need.pdf"

# Example prompts shown under the chatbot; each one is wrapped in its own
# single-item list.
examples_questions = [
    [question]
    for question in (
        "What is Transformer?",
        "What is Attention?",
        "What is Scaled Dot-Product Attention?",
        "What are Encoder and Decoder?",
        "Describe more about the Transformer.",
        "Why use self-attention?",
    )
]

# RAG prompt: {context} is filled with the retrieved chunks and {question}
# with the user's query.
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Always say "Thanks for asking!" at the end of the answer.

{context}

Question: {question}

Answer:
"""

# Function to handle PDF upload and indexing
def index_pdf(pdf):
    global vector_store
    
    # Load the PDF
    loader = PyPDFLoader(pdf.name)
    documents = loader.load()

    # Split the documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)

    # Embed the chunks 
    embeddings = HuggingFaceEmbeddings(model_name="bert-base-uncased", encode_kwargs={"normalize_embeddings": True})

    # Store the embeddings in the vector store
    vector_store = FAISS.from_documents(texts, embeddings)

    return "PDF indexed successfully!"

def load_sample_pdf():
    global vector_store
    
    # Load the PDF
    loader = PyPDFLoader(sample_filename)
    documents = loader.load()

    # Split the documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)

    # Embed the chunks 
    embeddings = HuggingFaceEmbeddings(model_name="bert-base-uncased", encode_kwargs={"normalize_embeddings": True})

    # Store the embeddings in the vector store
    vector_store = FAISS.from_documents(texts, embeddings)

    return "Sample PDF indexed successfully!"


def format_docs(docs):
    """Concatenate the text of retrieved documents into one context string.

    Each document's ``page_content`` is kept in retrieval order, and
    consecutive documents are separated by a blank line.
    """
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)
    

def generate_response(query, history, model, temperature, max_tokens, top_p, seed):
    """Answer *query* with retrieval-augmented generation over the indexed PDF.

    Args:
        query: The user's question.
        history: Chat history supplied by ``gr.ChatInterface``. It is not
            used; each question is answered independently.
        model: Groq model identifier chosen in the UI.
        temperature: Sampling temperature forwarded to the LLM.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass forwarded to the LLM.
        seed: Sampling seed. ``0`` (or an empty field) means "pick a random
            seed".

    Returns:
        The model's answer, or a hint to index a PDF first.
    """
    if vector_store is None:
        return "Please upload and index a PDF at the Indexing tab."

    # A cleared gr.Number field arrives as None; treat it like 0 (random).
    if not seed:
        seed = random.randint(1, 100000)

    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 16})
    # Forward the UI's sampling controls. Previously they were collected but
    # never passed, so the sliders and seed had no effect. top_p and seed are
    # not dedicated ChatGroq fields, so they go through model_kwargs into the
    # Groq request.
    llm = ChatGroq(
        groq_api_key=os.environ.get("GROQ_API_KEY"),
        model=model,
        temperature=temperature,
        max_tokens=int(max_tokens),
        model_kwargs={"top_p": top_p, "seed": int(seed)},
    )
    custom_rag_prompt = PromptTemplate.from_template(template)

    # The question is retrieved against and also passed through unchanged
    # as {question}.
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | custom_rag_prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain.invoke(query)



# Groq model identifiers offered in the chat UI; the first entry is the default.
MODEL_CHOICES = [
    "llama-3.1-70b-versatile",
    "llama-3.1-8b-instant",
    "llama3-70b-8192",
    "llama3-8b-8192",
    "mixtral-8x7b-32768",
    "gemma2-9b-it",
    "gemma-7b-it",
]

# Extra controls for the chatbot. gr.ChatInterface passes their values to
# generate_response after (query, history), in this order, so the list order
# must match: model, temperature, max_tokens, top_p, seed.
additional_inputs = [
    gr.Dropdown(choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Model"),
    gr.Slider(
        minimum=0.0, maximum=1.0, step=0.01, value=0.5, label="Temperature",
        info="Controls diversity of the generated text. Lower is more deterministic, higher is more creative.",
    ),
    gr.Slider(
        minimum=1, maximum=8000, step=1, value=8000, label="Max Tokens",
        info=(
            "The maximum number of tokens the model may generate in a single response."
            "<br>Context windows: 8k for gemma-7b-it, gemma2-9b-it, llama3-8b & llama3-70b; "
            "32k for mixtral-8x7b; 128k for llama-3.1."
        ),
    ),
    gr.Slider(
        minimum=0.0, maximum=1.0, step=0.01, value=0.5, label="Top P",
        info="Nucleus sampling: the model only considers the most probable next tokens whose cumulative probability reaches p.",
    ),
    gr.Number(precision=0, value=0, label="Seed", info="A starting point to initiate generation, use 0 for random"),
]

# Create the Gradio interface.
# NOTE: Gradio renders components in the order they are created inside the
# Blocks/Tab contexts, so the statement order below defines the layout.
with gr.Blocks(theme="Nymbo/Alyx_Theme") as demo:
    # Tab 1: build the FAISS index from an uploaded PDF or the bundled sample.
    with gr.Tab("Indexing"):
        pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        index_button = gr.Button("Index PDF")
        load_sample = gr.Button("Alternatively, Load and Index [Attention Is All You Need.pdf] as a Sample")
        # Both buttons report their status message into this textbox.
        index_output = gr.Textbox(label="Indexing Status")
        index_button.click(index_pdf, inputs=pdf_input, outputs=index_output)
        # load_sample_pdf takes no UI inputs; it reads sample_filename itself.
        load_sample.click(load_sample_pdf, inputs=None, outputs=index_output)
    
    # Tab 2: chat over the indexed PDF. additional_inputs supply the model and
    # sampling settings passed to generate_response after (query, history).
    with gr.Tab("Chatbot"):
        gr.ChatInterface(
            fn=generate_response, 
            chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
            examples=examples_questions,
            additional_inputs=additional_inputs,
        )       

# Launch the Gradio app.
# NOTE(review): this runs at import time; consider an
# `if __name__ == "__main__":` guard if the module is ever imported rather
# than executed directly.
demo.launch()