import yaml
import fitz
import torch
import gradio as gr
from PIL import Image
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.prompts import PromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import spaces
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter


class PDFChatBot:
    def __init__(self, config_path="config.yaml"):
        """
        Initialize the PDFChatBot instance.

        Parameters:
            config_path (str): Path to the configuration file (default is "config.yaml").
        """
        self.processed = False
        self.page = 0
        self.chat_history = []
        # Initialize other attributes to None
        self.prompt = None
        self.documents = None
        self.embeddings = None
        self.vectordb = None
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.chain = None
        self.chunk_size = None
        self.current_context = None
        self.format_separator = "\n\n--\n\n"
        #self.chunk_size_slider = chunk_size_slider

    def load_embeddings(self):
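        """Load the HuggingFace sentence-transformer embedding model used for retrieval."""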

        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        print("Embedding model loaded")

    def load_vectordb(self):
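        """Split the loaded PDF into overlapping chunks and index them in a Chroma vector store."""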
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=256,
            chunk_overlap=100,
            length_function=len,
            add_start_index=True,
        )
        docs = text_splitter.split_documents(self.documents)
        self.vectordb = Chroma.from_documents(docs, self.embeddings)
        print("Vector store created")
    @spaces.GPU
    def load_tokenizer(self):
        self.tokenizer = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    @spaces.GPU
    def create_organic_pipeline(self):
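        """Build the Llama-3 text-generation pipeline on the GPU in bfloat16."""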
        self.pipeline = pipeline(
            "text-generation",
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device="cuda",
        )
        self.load_tokenizer()
        print("Model pipeline loaded")

    def get_organic_context(self, query):
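        """Retrieve the top-3 most relevant chunks for the query and cache them as the current context."""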
        documents = self.vectordb.similarity_search_with_relevance_scores(query, k=3)
        context = self.format_separator.join([doc.page_content for doc, score in documents])
        self.current_context = context
        print("Context Ready")
        print(self.current_context)
    @spaces.GPU
    def create_organic_response(self, history, query):
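        """Generate an answer to the query, grounded in the retrieved PDF context."""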
        self.get_organic_context(query)
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
        messages = [
            {"role": "system", "content": "From the the contained given below, answer the question of user \n " + self.current_context},
            {"role": "user", "content": query},
        ]

        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        temp = 0.1
        outputs = self.pipeline(
            prompt,
            max_new_tokens=1024,
            do_sample=True,
            temperature=temp,
            top_p=0.9,
        )
        return outputs[0]["generated_text"][len(prompt):]


    def process_file(self, file):
        """
        Process the uploaded PDF file and initialize necessary components: Tokenizer, VectorDB and LLM.

        Parameters:
            file (FileStorage): The uploaded PDF file.
        """
        self.documents = PyPDFLoader(file.name).load()
        self.load_embeddings()
        self.load_vectordb()
        self.create_organic_pipeline()  # build the LLM pipeline so generate_response can run

    @spaces.GPU
    def generate_response(self, history, query, file):
        """Answer the query against the uploaded PDF, processing the file on first use."""
        if not query:
            raise gr.Error(message='Submit a question')
        if not file:
            raise gr.Error(message='Upload a PDF')
        if not self.processed:
            self.process_file(file)
            self.processed = True

        result = self.create_organic_response(history="", query=query)
        return result, ""

    def render_file(self, file, chunk_size):
        """Render the current page of the uploaded PDF as a 300-DPI PIL image."""
        self.chunk_size = chunk_size
        doc = fitz.open(file.name)
        page = doc[self.page]
        pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
        image = Image.frombytes('RGB', [pix.width, pix.height], pix.samples)
        return image

    def add_text(self, history, text):
        """
        Add user-entered text to the chat history.
        Parameters:
            history (list): List of chat history tuples.
            text (str): User-entered text.
        Returns:
            list: Updated chat history.
        """
        if not text:
            raise gr.Error('Enter text')
        history.append((text, ''))
        return history
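

# ------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal Gradio
# Blocks app wired around PDFChatBot. The component names and layout
# below are assumptions; the actual Space likely defines its own UI.
# ------------------------------------------------------------------
if __name__ == "__main__":
    pdf_chatbot = PDFChatBot()

    with gr.Blocks() as demo:
        with gr.Row():
            chat_history = gr.Chatbot(label="Chat")
            show_img = gr.Image(label="PDF preview")
        with gr.Row():
            txt = gr.Textbox(placeholder="Ask a question about the PDF")
            answer_box = gr.Textbox(label="Answer")
        uploaded_pdf = gr.File(label="Upload PDF", file_types=[".pdf"])
        chunk_size_slider = gr.Slider(64, 1024, value=256, label="Chunk size")

        # Render the first page whenever a PDF is uploaded.
        uploaded_pdf.upload(
            pdf_chatbot.render_file,
            inputs=[uploaded_pdf, chunk_size_slider],
            outputs=[show_img],
        )

        # Append the user's message, then answer it; generate_response
        # returns (answer, "") so the second output clears the textbox.
        txt.submit(
            pdf_chatbot.add_text, [chat_history, txt], [chat_history]
        ).then(
            pdf_chatbot.generate_response,
            inputs=[chat_history, txt, uploaded_pdf],
            outputs=[answer_box, txt],
        )

    demo.launch()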