Ubai committed (verified)
Commit 0abb90d · Parent: fc1e558

Update app.py

Files changed (1):
  1. app.py (+35 −43)
app.py CHANGED
@@ -14,16 +14,7 @@ from langchain.llms import HuggingFaceHub
 from pathlib import Path
 import chromadb
 
-from transformers import AutoTokenizer
-import transformers
-import torch
-import tqdm
-import accelerate
-
-# Update list of LLM models
-list_llm = ["mistralai/Mistral-7B-Instruct-v0.2"]
-list_llm_simple = [os.path.basename(llm) for llm in list_llm]
-
+# Load PDF document and create doc splits
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
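The middle of `load_doc` falls between hunks, so the diff never shows how `pages` is filled or how the splitter is built. A minimal sketch of what the full function plausibly looks like, assuming LangChain's `RecursiveCharacterTextSplitter` (the diff only confirms the signature, the `PyPDFLoader` list, and the final `split_documents` call):

    from langchain.document_loaders import PyPDFLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    def load_doc(list_file_path, chunk_size, chunk_overlap):
        # Load every PDF and collect its pages into one list
        loaders = [PyPDFLoader(x) for x in list_file_path]
        pages = []
        for loader in loaders:
            pages.extend(loader.load())
        # Assumption: a recursive character splitter; only chunk_size,
        # chunk_overlap, and split_documents are confirmed by the diff
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        doc_splits = text_splitter.split_documents(pages)
        return doc_splits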
@@ -35,6 +26,7 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
+# Create vector database
 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
     new_client = chromadb.EphemeralClient()
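The body of `create_db` between the ephemeral client and the `return` is likewise outside the hunks. A sketch of the elided construction, assuming the standard `Chroma.from_documents` call from `langchain.vectorstores` (the `vectordb` name, embedding, and client are confirmed; the exact keyword arguments are an assumption):

    from langchain.vectorstores import Chroma
    from langchain.embeddings import HuggingFaceEmbeddings
    import chromadb

    def create_db(splits, collection_name):
        embedding = HuggingFaceEmbeddings()
        new_client = chromadb.EphemeralClient()
        # Assumption: standard from_documents constructor; the diff shows
        # only the closing parenthesis and the return
        vectordb = Chroma.from_documents(
            documents=splits,
            embedding=embedding,
            client=new_client,
            collection_name=collection_name,
        )
        return vectordb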
@@ -46,6 +38,7 @@ def create_db(splits, collection_name):
     )
     return vectordb
 
+# Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     llm = HuggingFaceHub(
         repo_id=llm_model,
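The rest of the `HuggingFaceHub` call and the memory/retriever setup also sit between hunks. A plausible reconstruction, assuming the signature's `temperature`, `max_tokens`, and `top_k` feed `model_kwargs` and that the memory is a `ConversationBufferMemory` (both assumptions; only `repo_id=llm_model` and the later `retriever=retriever` line are confirmed):

    from langchain.llms import HuggingFaceHub
    from langchain.memory import ConversationBufferMemory

    llm = HuggingFaceHub(
        repo_id=llm_model,
        # Assumption: sampling parameters passed through model_kwargs
        model_kwargs={
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "top_k": top_k,
        },
    )
    # Assumption: buffer memory keyed to the answer, as is typical when
    # ConversationalRetrievalChain returns source documents
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    retriever = vector_db.as_retriever()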
@@ -60,7 +53,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
-        chain_type="stuff",
+        chain_type="stuff",
         memory=memory,
         return_source_documents=True,
         verbose=False,
@@ -68,18 +61,20 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     progress(0.9, desc="Done!")
     return qa_chain
 
-def initialize_database(list_file_obj, chunk_size, chunk_overlap, llm_temperature, max_tokens, top_k, progress=gr.Progress()):
+# Initialize database and LLM chain
+def initialize_demo(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
     list_file_path = [x.name for x in list_file_obj if x is not None]
     collection_name = Path(list_file_path[0]).stem.replace(" ", "-")[:50]
     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
     vector_db = create_db(doc_splits, collection_name)
     qa_chain = initialize_llmchain(
-        list_llm[0],
-        llm_temperature,
-        max_tokens,
-        top_k,
-        vector_db,
-        progress)
+        "mistralai/Mistral-7B-Instruct-v0.2",
+        0.7,
+        1024,
+        3,
+        vector_db,
+        progress
+    )
     return vector_db, collection_name, qa_chain, "Complete!"
 
 def format_chat_history(message, chat_history):
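`initialize_demo` now hardcodes the model ID and the sampling parameters that `initialize_database` used to take as arguments. A hypothetical call from plain Python, assuming file-like objects with a `.name` attribute (which is what Gradio's file components provide) and a no-op progress callback so it runs outside a Gradio event:

    from types import SimpleNamespace

    # Hypothetical stand-ins for Gradio file objects: anything with a
    # .name attribute pointing at a PDF path works for initialize_demo
    files = [SimpleNamespace(name="paper.pdf")]

    vector_db, collection_name, qa_chain, status = initialize_demo(
        files,
        600,  # chunk_size (the slider default in demo())
        40,   # chunk_overlap (the slider default in demo())
        progress=lambda *args, **kwargs: None,  # no-op outside an event
    )
    print(status)  # -> "Complete!"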
@@ -105,13 +100,6 @@ def conversation(qa_chain, message, history):
     new_history = history + [(message, response_answer)]
     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
-def upload_file(file_obj):
-    list_file_path = []
-    for idx, file in enumerate(file_obj):
-        file_path = file_obj.name
-        list_file_path.append(file_path)
-    return list_file_path
-
 def demo():
     with gr.Blocks(theme="base") as demo:
         vector_db = gr.State()
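The deleted `upload_file` helper read `file_obj.name` inside the loop instead of `file.name`, so it collected the same path repeatedly; the commit drops it in favor of the comprehension already inside `initialize_demo`:

    # Equivalent of the removed helper, as inlined in initialize_demo
    list_file_path = [x.name for x in list_file_obj if x is not None]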
@@ -125,21 +113,25 @@ def demo():
     When generating answers, it takes past questions into account (via conversational memory), and includes document references for clarity purposes.</i>
     <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate an output.<br>
     """)
-        with gr.Tab("Chatbot"):
-            with gr.Row():
-                document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-                db_btn = gr.Button("Generate vector database...")
-            with gr.Accordion("Advanced options - Document text splitter", open=False):
-                with gr.Row():
-                    slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
-                with gr.Row():
-                    slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
-                with gr.Row():
-                    db_progress = gr.Textbox(label="Vector database initialization", value="None")
-                with gr.Row():
-                    llm_btn = gr.Radio(list_llm_simple, label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model")
-            with gr.Accordion("Advanced options - LLM model", open=False):
-                with gr.Row():
-                    slider_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
-                with gr.Row():
-                    slider_maxtokens = gr
+
+        document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
+        slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
+        slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
+        db_progress = gr.Textbox(label="Vector database initialization", value="None")
+
+        # Initialize vector database and LLM chain in the background
+        vector_db, collection_name, qa_chain, status = initialize_demo([document], slider_chunk_size, slider_chunk_overlap, db_progress)
+
+        chatbot = gr.Chatbot(height=300)
+        msg = gr.Textbox(placeholder="Type message", container=True)
+        submit_btn = gr.Button("Submit")
+        clear_btn = gr.ClearButton([msg, chatbot])
+
+        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot], queue=False)
+        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot], queue=False)
+        clear_btn.click(lambda:[None,"",0,"",0,"",0], inputs=None, outputs=[chatbot], queue=False)
+
+    demo.queue().launch(debug=True)
+
+if __name__ == "__main__":
+    demo()
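Two notes on this last hunk. First, the final deleted line (`slider_maxtokens = gr`) is cut off in the source diff view; the remaining advanced-option widgets it belonged to are simply dropped by this commit. Second, and more substantively, the new code calls `initialize_demo([document], slider_chunk_size, slider_chunk_overlap, db_progress)` at Blocks-build time, when `document` and the sliders are still Gradio component objects rather than user values; Gradio only resolves components to values inside registered event handlers. A minimal sketch of event-driven wiring that would run the pipeline on real inputs, using a hypothetical trigger button and `gr.State` holders (none of which are in the commit):

    # Hypothetical wiring, not the committed code: components resolve to
    # values only inside event callbacks.
    vector_db_state = gr.State()
    collection_state = gr.State()
    qa_chain_state = gr.State()
    db_btn = gr.Button("Generate vector database")  # hypothetical trigger

    db_btn.click(
        initialize_demo,
        inputs=[document, slider_chunk_size, slider_chunk_overlap],
        outputs=[vector_db_state, collection_state, qa_chain_state, db_progress],
    )

With this wiring, `initialize_demo`'s default `progress=gr.Progress()` is injected by Gradio during the click event, and its four return values land in the three state holders plus the `db_progress` textbox.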
 
 
 
 
 
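A related arity note: per the hunk at `@@ -105,13 +100,6 @@`, `conversation` returns nine values, but the new `msg.submit` / `submit_btn.click` registrations list only three outputs, and the clear handler returns a seven-element list into the single `chatbot` output. A hedged sketch that keeps handlers and outputs consistent, reusing `qa_chain_state` from the sketch above (the wrapper name is hypothetical):

    # Hypothetical wrapper trimming conversation's nine return values to
    # the three outputs actually registered in the reduced UI.
    def conversation_ui(qa_chain, message, history):
        qa_chain, cleared_msg, new_history, *_sources = conversation(qa_chain, message, history)
        return qa_chain, cleared_msg, new_history

    msg.submit(conversation_ui, inputs=[qa_chain_state, msg, chatbot],
               outputs=[qa_chain_state, msg, chatbot], queue=False)
    submit_btn.click(conversation_ui, inputs=[qa_chain_state, msg, chatbot],
                     outputs=[qa_chain_state, msg, chatbot], queue=False)
    clear_btn.click(lambda: None, inputs=None, outputs=[chatbot], queue=False)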