TRaw commited on
Commit
dc3afaa
·
1 Parent(s): b301e28

Upload 5 files

Browse files
src/gradio_chatbot.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from model_utils import *
3
+
4
+
5
+ with gr.Blocks(gr.themes.Soft(primary_hue=gr.themes.colors.slate, secondary_hue=gr.themes.colors.purple)) as demo:
6
+ gr.Markdown('''# Retrieval Augmented Generation \n
7
+ RAG (Retrieval-Augmented Generation) addresses the data freshness problem in Large Language Models (LLMs) like Llama-2, which lack awareness of recent events. LLMs perceive the world only through their training data, leading to challenges when needing up-to-date information or specific datasets. To tackle this, retrieval augmentation is employed, enabling relevant external knowledge from a knowledge base to be incorporated into LLM responses.
8
+ RAG involves creating a knowledge base containing two types of knowledge: parametric knowledge from LLM training and source knowledge from external input. Data for the knowledge base is derived from datasets relevant to the use case, which are then processed into smaller chunks to enhance relevance and efficiency. Token embeddings, generated using models like RoBERTa, are crucial for retrieving context and meaning from the knowledge base.
9
+ A vector database could be used to manage and search through the embeddings efficiently. The LangChain library facilitates interactions with the knowledge base, allowing LLMs to generate responses based on retrieved information. Generative Question Answering (GQA) or Retrieval Augmented Generation (RAG) techniques instruct the LLM to craft answers using knowledge base content. To enhance trust, answers can be accompanied by citations indicating the information source.
10
+ RAG leverages a combination of external knowledge and LLM capabilities to provide accurate, up-to-date, and well-grounded responses. This approach is gaining traction in products such as AI search engines and conversational agents, highlighting the synergy between LLMs and robust knowledge bases.
11
+ ''')
12
+ with gr.Row():
13
+
14
+ with gr.Column(scale=0.5, variant = 'panel'):
15
+ gr.Markdown("## Upload Document & Select the Embedding Model")
16
+ file = gr.File(type="file")
17
+ with gr.Row(equal_height=True):
18
+
19
+ with gr.Column(scale=0.5, variant = 'panel'):
20
+ embedding_model = gr.Dropdown(choices= ["all-roberta-large-v1_1024d", "all-mpnet-base-v2_768d"],
21
+ value="all-roberta-large-v1_1024d",
22
+ label= "Select the embedding model")
23
+
24
+ with gr.Column(scale=0.5, variant='compact'):
25
+ vector_index_btn = gr.Button('Create vector store', variant='primary',scale=1)
26
+ vector_index_msg_out = gr.Textbox(show_label=False, lines=1,scale=1, placeholder="Creating vectore store ...")
27
+
28
+ instruction = gr.Textbox(label="System instruction", lines=3, value="Use the following pieces of context to answer the question at the end by. Generate the answer based on the given context only.If you do not find any information related to the question in the given context, just say that you don't know, don't try to make up an answer. Keep your answer expressive.")
29
+ reset_inst_btn = gr.Button('Reset',variant='primary', size = 'sm')
30
+
31
+ with gr.Accordion(label="Text generation tuning parameters"):
32
+ temperature = gr.Slider(label="temperature", minimum=0.1, maximum=1, value=0.1, step=0.05)
33
+ max_new_tokens = gr.Slider(label="max_new_tokens", minimum=1, maximum=2048, value=512, step=1)
34
+ repetition_penalty = gr.Slider(label="repetition_penalty", minimum=0, maximum=2, value=1.1, step=0.1)
35
+ top_k= gr.Slider(label="top_k", minimum=1, maximum=1000, value=10, step=1)
36
+ top_p=gr.Slider(label="top_p", minimum=0, maximum=1, value=0.95, step=0.05)
37
+ k_context=gr.Slider(label="k_context", minimum=1, maximum=15, value=5, step=1)
38
+
39
+ vector_index_btn.click(upload_and_create_vector_store,[file,embedding_model],vector_index_msg_out)
40
+ reset_inst_btn.click(reset_sys_instruction,instruction,instruction)
41
+
42
+ with gr.Column(scale=0.5, variant = 'panel'):
43
+ gr.Markdown("## Select the Generation Model")
44
+
45
+ with gr.Row(equal_height=True):
46
+
47
+ with gr.Column(scale=0.5):
48
+ llm = gr.Dropdown(choices= ["Llamav2-7B-Chat", "Falcon-7B-Instruct"], value="Llamav2-7B-Chat", label="Select the LLM")
49
+ hf_token = gr.Textbox(label='Enter your valid HF token_id', type = "password")
50
+
51
+ with gr.Column(scale=0.5):
52
+ model_load_btn = gr.Button('Load model', variant='primary',scale=2)
53
+ load_success_msg = gr.Textbox(show_label=False,lines=1, placeholder="Model loading ...")
54
+ chatbot = gr.Chatbot([], elem_id="chatbot",
55
+ label='Chatbox', height=725, )
56
+
57
+ txt = gr.Textbox(label= "Question",lines=2,placeholder="Enter your question and press shift+enter ")
58
+
59
+ with gr.Row():
60
+
61
+ with gr.Column(scale=0.5):
62
+ submit_btn = gr.Button('Submit',variant='primary', size = 'sm')
63
+
64
+ with gr.Column(scale=0.5):
65
+ clear_btn = gr.Button('Clear',variant='stop',size = 'sm')
66
+
67
+ model_load_btn.click(load_models, [hf_token,embedding_model,llm], load_success_msg, api_name="load_models")
68
+
69
+ txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
70
+ bot, [chatbot,instruction,temperature,max_new_tokens,repetition_penalty,top_k,top_p,k_context], chatbot)
71
+ submit_btn.click(add_text, [chatbot, txt], [chatbot, txt]).then(
72
+ bot, [chatbot,instruction,temperature, max_new_tokens,repetition_penalty,top_k,top_p,k_context], chatbot).then(
73
+ clear_cuda_cache, None, None
74
+ )
75
+
76
+ clear_btn.click(lambda: None, None, chatbot, queue=False)
77
+
78
+
79
+ if __name__ == '__main__':
80
+ demo.queue(concurrency_count=3)
81
+ demo.launch(debug=True, share=True)
src/model_setup.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,gc,shutil
2
+ from util.conversation_rag import Conversation_RAG
3
+ from util.index import *
4
+ import torch
5
+
6
+ class ModelSetup:
7
+ def __init__(self, hf_token, embedding_model, llm):
8
+
9
+ self.hf_token = hf_token
10
+ self.embedding_model = embedding_model
11
+ self.llm = llm
12
+
13
+ def setup(self):
14
+
15
+ if self.embedding_model == "all-roberta-large-v1_1024d":
16
+ embedding_model_repo_id = "sentence-transformers/all-roberta-large-v1"
17
+ elif self.embedding_model == "all-mpnet-base-v2_768d":
18
+ embedding_model_repo_id = "sentence-transformers/all-mpnet-base-v2"
19
+
20
+
21
+ if self.llm == "Llamav2-7B-Chat":
22
+ llm_repo_id = "meta-llama/Llama-2-7b-chat-hf"
23
+ elif self.llm == "Falcon-7B-Instruct":
24
+ llm_repo_id = "tiiuae/falcon-7b-instruct"
25
+
26
+
27
+ conv_rag = Conversation_RAG(self.hf_token,
28
+ embedding_model_repo_id,
29
+ llm_repo_id)
30
+
31
+ self.model, self.tokenizer, self.vectordb = conv_rag.load_model_and_tokenizer()
32
+ return "Model Setup Complete"
src/model_utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,gc,shutil
2
+ import gradio as gr
3
+ from util.conversation_rag import Conversation_RAG
4
+ from util.index import *
5
+ import torch
6
+ from model_setup import ModelSetup
7
+
8
+
9
+ def load_models(hf_token,embedding_model,llm):
10
+
11
+ global model_setup
12
+ model_setup = ModelSetup(hf_token, embedding_model, llm)
13
+ success_prompt = model_setup.setup()
14
+ return success_prompt
15
+
16
+
17
+ def upload_and_create_vector_store(file,embedding_model):
18
+
19
+ # Save the uploaded file to a permanent location
20
+ file_path = file.name
21
+ split_file_name = file_path.split("/")
22
+ file_name = split_file_name[-1]
23
+
24
+ current_folder = os.path.dirname(os.path.abspath(__file__))
25
+ root_folder = os.path.dirname(current_folder)
26
+ data_folder = os.path.join(root_folder, "data")
27
+ permanent_file_path = os.path.join(data_folder, file_name)
28
+ shutil.copy(file.name, permanent_file_path)
29
+
30
+ # Access the path of the saved file
31
+ print(f"File saved to: {permanent_file_path}")
32
+
33
+ if embedding_model == "all-roberta-large-v1_1024d":
34
+ embedding_model_repo_id = "sentence-transformers/all-roberta-large-v1"
35
+ elif embedding_model == "all-mpnet-base-v2_768d":
36
+ embedding_model_repo_id = "sentence-transformers/all-mpnet-base-v2"
37
+
38
+ index_success_msg = create_vector_store_index(permanent_file_path,embedding_model_repo_id)
39
+ return index_success_msg
40
+
41
+ def get_chat_history(inputs):
42
+
43
+ res = []
44
+ for human, ai in inputs:
45
+ res.append(f"Human:{human}\nAssistant:{ai}")
46
+ return "\n".join(res)
47
+
48
+ def add_text(history, text):
49
+
50
+ history = history + [[text, None]]
51
+ return history, ""
52
+
53
+ conv_qa = Conversation_RAG()
54
+ def bot(history,
55
+ instruction="Use the following pieces of context to answer the question at the end. Generate the answer based on the given context only if you find the answer in the context. If you do not find any information related to the question in the given context, just say that you don't know, don't try to make up an answer. Keep your answer expressive.",
56
+ temperature=0.1,
57
+ max_new_tokens=512,
58
+ repetition_penalty=1.1,
59
+ top_k=10,
60
+ top_p=0.95,
61
+ k_context=5,
62
+ num_return_sequences=1,
63
+ ):
64
+
65
+ qa = conv_qa.create_conversation(model_setup.model,
66
+ model_setup.tokenizer,
67
+ model_setup.vectordb,
68
+ max_new_tokens=max_new_tokens,
69
+ temperature=temperature,
70
+ repetition_penalty=repetition_penalty,
71
+ top_k=top_k,
72
+ top_p=top_p,
73
+ k_context=k_context,
74
+ num_return_sequences=num_return_sequences,
75
+ instruction=instruction
76
+
77
+ )
78
+
79
+ chat_history_formatted = get_chat_history(history[:-1])
80
+ res = qa(
81
+ {
82
+ 'question': history[-1][0],
83
+ 'chat_history': chat_history_formatted
84
+ }
85
+ )
86
+
87
+ history[-1][1] = res['answer']
88
+ return history
89
+
90
+ def reset_sys_instruction(instruction):
91
+
92
+ default_inst = "Use the following pieces of context to answer the question at the end. Generate the answer based on the given context only if you find the answer in the context. If you do not find any information related to the question in the given context, just say that you don't know, don't try to make up an answer. Keep your answer expressive."
93
+ return default_inst
94
+
95
+ def clear_cuda_cache():
96
+
97
+ torch.cuda.empty_cache()
98
+ gc.collect()
99
+ return None
src/util/conversation_rag.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import cuda, bfloat16
2
+ import transformers
3
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
4
+ from langchain.vectorstores import FAISS
5
+ from langchain.chains import ConversationalRetrievalChain
6
+ from langchain.memory import ConversationBufferMemory
7
+ from langchain.llms import HuggingFacePipeline
8
+ from huggingface_hub import login
9
+ from langchain.prompts import PromptTemplate
10
+
11
+
12
+ class Conversation_RAG:
13
+ def __init__(self, hf_token = "", embedding_model_repo_id="sentence-transformers/all-roberta-large-v1",
14
+ llm_repo_id='meta-llama/Llama-2-7b-chat-hf'):
15
+
16
+ self.hf_token = hf_token
17
+ self.embedding_model_repo_id = embedding_model_repo_id
18
+ self.llm_repo_id = llm_repo_id
19
+
20
+ def load_model_and_tokenizer(self):
21
+
22
+ embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_repo_id)
23
+ vectordb = FAISS.load_local("./db/faiss_index", embedding_model)
24
+
25
+ login(token=self.hf_token)
26
+
27
+ device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
28
+
29
+ bnb_config = transformers.BitsAndBytesConfig(
30
+ load_in_4bit=True,
31
+ bnb_4bit_quant_type='nf4',
32
+ bnb_4bit_use_double_quant=True,
33
+ bnb_4bit_compute_dtype=bfloat16
34
+ )
35
+
36
+ model = transformers.AutoModelForCausalLM.from_pretrained(
37
+ self.llm_repo_id,
38
+ trust_remote_code=True,
39
+ quantization_config=bnb_config,
40
+ load_in_8bit=True,
41
+ device_map='auto'
42
+ )
43
+ model.eval()
44
+
45
+ tokenizer = transformers.AutoTokenizer.from_pretrained(self.llm_repo_id)
46
+ return model, tokenizer, vectordb
47
+
48
+ def create_conversation(self, model, tokenizer, vectordb, max_new_tokens=512, temperature=0.1, repetition_penalty=1.1, top_k=10, top_p=0.95, k_context=5,
49
+ num_return_sequences=1, instruction="Use the following pieces of context to answer the question at the end by. Generate the answer based on the given context only. If you do not find any information related to the question in the given context, just say that you don't know, don't try to make up an answer. Keep your answer expressive."):
50
+
51
+ generate_text = transformers.pipeline(
52
+ model=model,
53
+ tokenizer=tokenizer,
54
+ return_full_text=True, # langchain expects the full text
55
+ task='text-generation',
56
+ temperature=temperature, # 'randomness' of outputs, 0.0 is the min and 1.0 the max
57
+ max_new_tokens=max_new_tokens, # mex number of tokens to generate in the output
58
+ repetition_penalty=repetition_penalty, # without this output begins repeating
59
+ top_k=top_k,
60
+ top_p=top_p,
61
+ num_return_sequences=num_return_sequences,
62
+ )
63
+
64
+ llm = HuggingFacePipeline(pipeline=generate_text)
65
+
66
+ system_instruction = f"User: {instruction}\n"
67
+ template = system_instruction + """
68
+ context:\n
69
+ {context}\n
70
+ Question: {question}\n
71
+ Assistant:
72
+ """
73
+
74
+ QCA_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
75
+
76
+ qa = ConversationalRetrievalChain.from_llm(
77
+ llm=llm,
78
+ chain_type='stuff',
79
+ retriever=vectordb.as_retriever(search_kwargs={"k": k_context}),
80
+ combine_docs_chain_kwargs={"prompt": QCA_PROMPT},
81
+ get_chat_history=lambda h: h,
82
+ verbose=True
83
+ )
84
+ return qa
85
+
src/util/index.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pickle
4
+
5
+ from langchain.vectorstores import FAISS, Chroma, DocArrayInMemorySearch
6
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
7
+ from langchain.document_loaders.csv_loader import CSVLoader
8
+ from langchain.text_splitter import CharacterTextSplitter
9
+ from langchain.document_loaders import PyPDFLoader
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+
12
+ def create_vector_store_index(file_path, embedding_model_repo_id="sentence-transformers/all-roberta-large-v1"):
13
+
14
+ file_path_split = file_path.split(".")
15
+ file_type = file_path_split[-1].rstrip('/')
16
+
17
+ if file_type == 'csv':
18
+ print(file_path)
19
+ loader = CSVLoader(file_path=file_path)
20
+ documents = loader.load()
21
+
22
+ elif file_type == 'pdf':
23
+ loader = PyPDFLoader(file_path)
24
+ pages = loader.load()
25
+
26
+ text_splitter = RecursiveCharacterTextSplitter(
27
+ chunk_size = 1024,
28
+ chunk_overlap = 128,)
29
+
30
+ documents = text_splitter.split_documents(pages)
31
+
32
+
33
+ embedding_model = HuggingFaceEmbeddings(
34
+ model_name=embedding_model_repo_id
35
+ )
36
+
37
+ vectordb = FAISS.from_documents(documents, embedding_model)
38
+ file_output = "./db/faiss_index"
39
+ vectordb.save_local(file_output)
40
+
41
+ return "Vector store index is created."