Update app.py
app.py
CHANGED
@@ -2,22 +2,18 @@ import os
 import glob
 import textwrap
 import time
-
 import langchain
 import locale
-
-
 import gradio as gr
+
 locale.getpreferredencoding = lambda: "UTF-8"
-### loaders
-from langchain.document_loaders import PyPDFLoader, DirectoryLoader
 
-
+from langchain.document_loaders import PyPDFLoader, DirectoryLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-### prompts
 from langchain import PromptTemplate, LLMChain
 
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+
 ### vector stores
 from langchain.vectorstores import FAISS
 
@@ -38,27 +34,52 @@ from transformers import (
 
 sorted(glob.glob('/content/anatomy_vol_*'))
 
-class CFG:
-    # LLMs
-    model_name = 'llama2-13b-chat' # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
-    temperature = 0
-    top_p = 0.95
-    repetition_penalty = 1.15
 
-
-
-
+def wrap_text_preserve_newlines(text, width=700):
+    # Split the input text into lines based on newline characters
+    lines = text.split('\n')
 
-    #
-
+    # Wrap each line individually
+    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
 
-    #
-
+    # Join the wrapped lines back together using newline characters
+    wrapped_text = '\n'.join(wrapped_lines)
+
+    return wrapped_text
 
-
-
-
-
+
+def process_llm_response(llm_response):
+    ans = wrap_text_preserve_newlines(llm_response['result'])
+
+    sources_used = ' \n'.join(
+        [
+            source.metadata['source'].split('/')[-1][:-4]
+            + ' - page: '
+            + str(source.metadata['page'])
+            for source in llm_response['source_documents']
+        ]
+    )
+
+    ans = ans + '\n\nSources: \n' + sources_used
+    return ans
+
+def llm_ans(query):
+    start = time.time()
+
+    llm_response = qa_chain.invoke(query)
+    ans = process_llm_response(llm_response)
+
+    end = time.time()
+
+    time_elapsed = int(round(end - start, 0))
+    time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
+    return ans + time_elapsed_str
+
+
+
+def predict(message, history):
+    output = str(llm_ans(message)).replace("\n", "<br/>")
+    return output
 
 def get_model(model = CFG.model_name):
 
@@ -251,6 +272,30 @@ def get_model(model = CFG.model_name):
 
     return tokenizer, model, max_len
 
+
+class CFG:
+    # LLMs
+    model_name = 'llama2-13b-chat' # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
+    temperature = 0
+    top_p = 0.95
+    repetition_penalty = 1.15
+
+    # splitting
+    split_chunk_size = 800
+    split_overlap = 0
+
+    # embeddings
+    embeddings_model_repo = 'sentence-transformers/all-MiniLM-L6-v2'
+
+    # similar passages
+    k = 6
+
+    # paths
+    PDFs_path = '/content/'
+    Embeddings_path = '/content/faiss-hp-sentence-transformers'
+    Output_folder = './rag-vectordb'
+
+
 tokenizer, model, max_len = get_model(model = CFG.model_name)
 
 pipe = pipeline(
@@ -285,8 +330,6 @@ text_splitter = RecursiveCharacterTextSplitter(
 
 texts = text_splitter.split_documents(documents)
 
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-
 vectordb = FAISS.from_documents(
     texts,
     HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
@@ -323,56 +366,10 @@ qa_chain = RetrievalQA.from_chain_type(
     verbose = False
 )
 
-def wrap_text_preserve_newlines(text, width=700):
-    # Split the input text into lines based on newline characters
-    lines = text.split('\n')
-
-    # Wrap each line individually
-    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
-
-    # Join the wrapped lines back together using newline characters
-    wrapped_text = '\n'.join(wrapped_lines)
-
-    return wrapped_text
-
-
-def process_llm_response(llm_response):
-    ans = wrap_text_preserve_newlines(llm_response['result'])
-
-    sources_used = ' \n'.join(
-        [
-            source.metadata['source'].split('/')[-1][:-4]
-            + ' - page: '
-            + str(source.metadata['page'])
-            for source in llm_response['source_documents']
-        ]
-    )
-
-    ans = ans + '\n\nSources: \n' + sources_used
-    return ans
-
-def llm_ans(query):
-    start = time.time()
-
-    llm_response = qa_chain.invoke(query)
-    ans = process_llm_response(llm_response)
-
-    end = time.time()
-
-    time_elapsed = int(round(end - start, 0))
-    time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
-    return ans + time_elapsed_str
-
-
-
-def predict(message, history):
-    output = str(llm_ans(message)).replace("\n", "<br/>")
-    return output
-
 demo = gr.ChatInterface(
     predict,
     title = f' Open-Source LLM ({CFG.model_name}) Question Answering'
 )
 
 demo.queue()
-demo.launch()
+demo.launch()
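After this change, the relocated CFG class also carries splitting, embedding, retrieval, and path settings. For reference, below is a minimal sketch of how those fields could feed the splitter, vector store, and retriever that app.py builds; the HuggingFacePipeline wrapper, the saved-index path, and the `documents` variable are assumptions based on the surrounding code, not part of this commit.

from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

# Chunk the loaded PDF pages with the sizes defined on CFG
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = CFG.split_chunk_size,
    chunk_overlap = CFG.split_overlap,
)
texts = text_splitter.split_documents(documents)  # documents: assumed output of DirectoryLoader/PyPDFLoader

# Embed with the model named on CFG and persist the FAISS index under CFG.Output_folder (assumed path)
embeddings = HuggingFaceEmbeddings(model_name = CFG.embeddings_model_repo)
vectordb = FAISS.from_documents(texts, embeddings)
vectordb.save_local(CFG.Output_folder)

# Answer with the local HF pipeline, retrieving CFG.k passages per query
llm = HuggingFacePipeline(pipeline = pipe)
qa_chain = RetrievalQA.from_chain_type(
    llm = llm,
    chain_type = 'stuff',
    retriever = vectordb.as_retriever(search_kwargs = {'k': CFG.k}),
    return_source_documents = True,
    verbose = False,
)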