Spaces:
Runtime error
Runtime error
Vincent Claes
commited on
Commit
·
13b16b6
1
Parent(s):
611aebd
first working end to end version
Browse files- app.py +62 -10
- import_data.py +18 -11
- requirements.txt +117 -0
app.py
CHANGED
@@ -2,32 +2,84 @@ import os
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import weaviate
|
|
|
|
|
|
|
|
|
5 |
|
6 |
collection_name = "Chunk"
|
7 |
|
|
|
|
|
|
|
8 |
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
client = weaviate.Client(
|
11 |
url=os.environ["WEAVIATE_URL"],
|
12 |
auth_client_secret=weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]),
|
13 |
-
additional_headers={
|
14 |
-
"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]
|
15 |
-
}
|
16 |
)
|
17 |
|
18 |
-
|
19 |
-
client.query
|
20 |
-
.
|
21 |
-
.
|
22 |
-
.with_limit(
|
23 |
-
.with_generate(single_prompt="{text}")
|
24 |
.do()
|
25 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
iface = gr.Interface(
|
28 |
fn=predict, # the function to wrap
|
29 |
inputs="text", # the input type
|
30 |
outputs="text", # the output type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
)
|
32 |
|
33 |
if __name__ == "__main__":
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import weaviate
|
5 |
+
from langchain import LLMChain
|
6 |
+
from langchain.chains import SequentialChain
|
7 |
+
from langchain.chat_models import ChatOpenAI
|
8 |
+
from langchain.prompts import ChatPromptTemplate
|
9 |
|
10 |
collection_name = "Chunk"
|
11 |
|
12 |
+
MODEL = "gpt-3.5-turbo"
|
13 |
+
LANGUAGE = "en" # nl / en
|
14 |
+
llm = ChatOpenAI(temperature=0.0, openai_api_key=os.environ["OPENAI_API_KEY"])
|
15 |
|
16 |
+
|
17 |
+
def get_answer_given_the_context(llm, prompt, context) -> SequentialChain:
|
18 |
+
template = f"""
|
19 |
+
Provide an answer to the prompt given the context.
|
20 |
+
|
21 |
+
<PROMPT>
|
22 |
+
|
23 |
+
{prompt}
|
24 |
+
|
25 |
+
<CONTEXT>
|
26 |
+
|
27 |
+
{context}
|
28 |
+
|
29 |
+
"""
|
30 |
+
|
31 |
+
prompt_get_skills_intersection = ChatPromptTemplate.from_template(template=template)
|
32 |
+
skills_match_chain = LLMChain(
|
33 |
+
llm=llm,
|
34 |
+
prompt=prompt_get_skills_intersection,
|
35 |
+
output_key="answer",
|
36 |
+
)
|
37 |
+
|
38 |
+
chain = SequentialChain(
|
39 |
+
chains=[skills_match_chain],
|
40 |
+
input_variables=["prompt", "context"],
|
41 |
+
output_variables=[
|
42 |
+
skills_match_chain.output_key,
|
43 |
+
],
|
44 |
+
verbose=False,
|
45 |
+
)
|
46 |
+
return chain({"prompt": prompt, "context": context})["answer"]
|
47 |
+
|
48 |
+
|
49 |
+
def predict(prompt):
|
50 |
client = weaviate.Client(
|
51 |
url=os.environ["WEAVIATE_URL"],
|
52 |
auth_client_secret=weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]),
|
53 |
+
additional_headers={"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]},
|
|
|
|
|
54 |
)
|
55 |
|
56 |
+
search_result = (
|
57 |
+
client.query.get(class_name=collection_name, properties=["text"])
|
58 |
+
.with_near_text({"concepts": prompt})
|
59 |
+
# .with_generate(single_prompt="{text}")
|
60 |
+
.with_limit(5)
|
|
|
61 |
.do()
|
62 |
)
|
63 |
+
context_list = [
|
64 |
+
element["text"] for element in search_result["data"]["Get"]["Chunk"]
|
65 |
+
]
|
66 |
+
context = "\n".join(context_list)
|
67 |
+
|
68 |
+
return get_answer_given_the_context(llm=llm, prompt=prompt, context=context)
|
69 |
+
|
70 |
|
71 |
iface = gr.Interface(
|
72 |
fn=predict, # the function to wrap
|
73 |
inputs="text", # the input type
|
74 |
outputs="text", # the output type
|
75 |
+
examples=[
|
76 |
+
[f"what is the process of raising an incident?"],
|
77 |
+
[f"What is Cx0 program management?"],
|
78 |
+
[
|
79 |
+
f"What is process for identifying risksthat can impact the desired outcomes of a project?"
|
80 |
+
],
|
81 |
+
[f"What is the release management process?"],
|
82 |
+
],
|
83 |
)
|
84 |
|
85 |
if __name__ == "__main__":
|
import_data.py
CHANGED
@@ -6,6 +6,7 @@ from llama_index import VectorStoreIndex, StorageContext
|
|
6 |
from pathlib import Path
|
7 |
import argparse
|
8 |
|
|
|
9 |
def get_pdf_files(base_path, loader):
|
10 |
"""
|
11 |
Get paths to all PDF files in a directory and its subdirectories.
|
@@ -22,13 +23,15 @@ def get_pdf_files(base_path, loader):
|
|
22 |
if not os.path.exists(base_path):
|
23 |
raise FileNotFoundError(f"The specified base path does not exist: {base_path}")
|
24 |
if not os.path.isdir(base_path):
|
25 |
-
raise NotADirectoryError(
|
|
|
|
|
26 |
|
27 |
# Loop through all directories and files starting from the base path
|
28 |
for root, dirs, files in os.walk(base_path):
|
29 |
for filename in files:
|
30 |
# If a file has a .pdf extension, add its path to the list
|
31 |
-
if filename.endswith(
|
32 |
pdf_file = loader.load_data(file=Path(root, filename))
|
33 |
pdf_paths.extend(pdf_file)
|
34 |
|
@@ -44,13 +47,13 @@ def main(args):
|
|
44 |
client = weaviate.Client(
|
45 |
url=os.environ["WEAVIATE_URL"],
|
46 |
auth_client_secret=weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]),
|
47 |
-
additional_headers={
|
48 |
-
"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]
|
49 |
-
}
|
50 |
)
|
51 |
|
52 |
# construct vector store
|
53 |
-
vector_store = WeaviateVectorStore(
|
|
|
|
|
54 |
|
55 |
# setting up the storage for the embeddings
|
56 |
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
@@ -63,11 +66,15 @@ def main(args):
|
|
63 |
|
64 |
|
65 |
if __name__ == "__main__":
|
66 |
-
parser = argparse.ArgumentParser(description=
|
67 |
-
|
68 |
-
parser.add_argument(
|
69 |
-
parser.add_argument(
|
70 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
71 |
|
72 |
args = parser.parse_args()
|
73 |
|
|
|
6 |
from pathlib import Path
|
7 |
import argparse
|
8 |
|
9 |
+
|
10 |
def get_pdf_files(base_path, loader):
|
11 |
"""
|
12 |
Get paths to all PDF files in a directory and its subdirectories.
|
|
|
23 |
if not os.path.exists(base_path):
|
24 |
raise FileNotFoundError(f"The specified base path does not exist: {base_path}")
|
25 |
if not os.path.isdir(base_path):
|
26 |
+
raise NotADirectoryError(
|
27 |
+
f"The specified base_path is not a directory: {base_path}"
|
28 |
+
)
|
29 |
|
30 |
# Loop through all directories and files starting from the base path
|
31 |
for root, dirs, files in os.walk(base_path):
|
32 |
for filename in files:
|
33 |
# If a file has a .pdf extension, add its path to the list
|
34 |
+
if filename.endswith(".pdf"):
|
35 |
pdf_file = loader.load_data(file=Path(root, filename))
|
36 |
pdf_paths.extend(pdf_file)
|
37 |
|
|
|
47 |
client = weaviate.Client(
|
48 |
url=os.environ["WEAVIATE_URL"],
|
49 |
auth_client_secret=weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]),
|
50 |
+
additional_headers={"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]},
|
|
|
|
|
51 |
)
|
52 |
|
53 |
# construct vector store
|
54 |
+
vector_store = WeaviateVectorStore(
|
55 |
+
weaviate_client=client, index_name=args.customer, text_key="content"
|
56 |
+
)
|
57 |
|
58 |
# setting up the storage for the embeddings
|
59 |
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
|
|
66 |
|
67 |
|
68 |
if __name__ == "__main__":
|
69 |
+
parser = argparse.ArgumentParser(description="Process and query PDF files.")
|
70 |
+
|
71 |
+
parser.add_argument("--customer", default="Ausy", help="Customer name")
|
72 |
+
parser.add_argument("--pdf_dir", default="./data", help="Directory containing PDFs")
|
73 |
+
parser.add_argument(
|
74 |
+
"--query",
|
75 |
+
default="What is CX0 customer exprience office?",
|
76 |
+
help="Query to execute",
|
77 |
+
)
|
78 |
|
79 |
args = parser.parse_args()
|
80 |
|
requirements.txt
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==23.2.1 ; python_version >= "3.9" and python_version < "4.0"
|
2 |
+
aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "4.0"
|
3 |
+
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
|
4 |
+
altair==5.1.1 ; python_version >= "3.9" and python_version < "4.0"
|
5 |
+
annotated-types==0.5.0 ; python_version >= "3.9" and python_version < "4.0"
|
6 |
+
anyio==3.7.1 ; python_version >= "3.9" and python_version < "4.0"
|
7 |
+
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "4.0"
|
8 |
+
attrs==23.1.0 ; python_version >= "3.9" and python_version < "4.0"
|
9 |
+
authlib==1.2.1 ; python_version >= "3.9" and python_version < "4.0"
|
10 |
+
beautifulsoup4==4.12.2 ; python_version >= "3.9" and python_version < "4.0"
|
11 |
+
blis==0.7.10 ; python_version >= "3.9" and python_version < "4.0"
|
12 |
+
catalogue==2.0.9 ; python_version >= "3.9" and python_version < "4.0"
|
13 |
+
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "4.0"
|
14 |
+
cffi==1.15.1 ; python_version >= "3.9" and python_version < "4.0"
|
15 |
+
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "4.0"
|
16 |
+
click==8.1.7 ; python_version >= "3.9" and python_version < "4.0"
|
17 |
+
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32")
|
18 |
+
confection==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
19 |
+
contourpy==1.1.1 ; python_version >= "3.9" and python_version < "4.0"
|
20 |
+
cryptography==41.0.4 ; python_version >= "3.9" and python_version < "4.0"
|
21 |
+
cycler==0.11.0 ; python_version >= "3.9" and python_version < "4.0"
|
22 |
+
cymem==2.0.8 ; python_version >= "3.9" and python_version < "4.0"
|
23 |
+
dataclasses-json==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
24 |
+
exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11"
|
25 |
+
fastapi==0.103.1 ; python_version >= "3.9" and python_version < "4.0"
|
26 |
+
ffmpy==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
|
27 |
+
filelock==3.12.4 ; python_version >= "3.9" and python_version < "4.0"
|
28 |
+
fonttools==4.42.1 ; python_version >= "3.9" and python_version < "4.0"
|
29 |
+
frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "4.0"
|
30 |
+
fsspec==2023.9.1 ; python_version >= "3.9" and python_version < "4.0"
|
31 |
+
goldenverba==0.2.3 ; python_version >= "3.9" and python_version < "4.0"
|
32 |
+
gradio-client==0.5.1 ; python_version >= "3.9" and python_version < "4.0"
|
33 |
+
gradio==3.44.4 ; python_version >= "3.9" and python_version < "4.0"
|
34 |
+
greenlet==2.0.2 ; python_version >= "3.9" and python_version < "4.0" and (platform_machine == "win32" or platform_machine == "WIN32" or platform_machine == "AMD64" or platform_machine == "amd64" or platform_machine == "x86_64" or platform_machine == "ppc64le" or platform_machine == "aarch64")
|
35 |
+
h11==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
|
36 |
+
httpcore==0.18.0 ; python_version >= "3.9" and python_version < "4.0"
|
37 |
+
httptools==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
38 |
+
httpx==0.25.0 ; python_version >= "3.9" and python_version < "4.0"
|
39 |
+
huggingface-hub==0.17.2 ; python_version >= "3.9" and python_version < "4.0"
|
40 |
+
idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
|
41 |
+
importlib-metadata==6.8.0 ; python_version >= "3.9" and python_version < "3.10"
|
42 |
+
importlib-resources==6.1.0 ; python_version >= "3.9" and python_version < "4.0"
|
43 |
+
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
|
44 |
+
joblib==1.3.2 ; python_version >= "3.9" and python_version < "4.0"
|
45 |
+
jsonschema-specifications==2023.7.1 ; python_version >= "3.9" and python_version < "4.0"
|
46 |
+
jsonschema==4.19.1 ; python_version >= "3.9" and python_version < "4.0"
|
47 |
+
kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "4.0"
|
48 |
+
langchain==0.0.296 ; python_version >= "3.9" and python_version < "4.0"
|
49 |
+
langcodes==3.3.0 ; python_version >= "3.9" and python_version < "4.0"
|
50 |
+
langsmith==0.0.38 ; python_version >= "3.9" and python_version < "4.0"
|
51 |
+
llama-index==0.8.29.post1 ; python_version >= "3.9" and python_version < "4.0"
|
52 |
+
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
53 |
+
marshmallow==3.20.1 ; python_version >= "3.9" and python_version < "4.0"
|
54 |
+
matplotlib==3.8.0 ; python_version >= "3.9" and python_version < "4.0"
|
55 |
+
multidict==6.0.4 ; python_version >= "3.9" and python_version < "4.0"
|
56 |
+
murmurhash==1.0.10 ; python_version >= "3.9" and python_version < "4.0"
|
57 |
+
mypy-extensions==1.0.0 ; python_version >= "3.9" and python_version < "4.0"
|
58 |
+
nest-asyncio==1.5.8 ; python_version >= "3.9" and python_version < "4.0"
|
59 |
+
nltk==3.8.1 ; python_version >= "3.9" and python_version < "4.0"
|
60 |
+
numexpr==2.8.6 ; python_version >= "3.9" and python_version < "4.0"
|
61 |
+
numpy==1.25.2 ; python_version >= "3.9" and python_version < "4.0"
|
62 |
+
openai==0.28.0 ; python_version >= "3.9" and python_version < "4.0"
|
63 |
+
orjson==3.9.7 ; python_version >= "3.9" and python_version < "4.0"
|
64 |
+
packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
|
65 |
+
pandas==2.1.0 ; python_version >= "3.9" and python_version < "4.0"
|
66 |
+
pathy==0.10.2 ; python_version >= "3.9" and python_version < "4.0"
|
67 |
+
pillow==10.0.1 ; python_version >= "3.9" and python_version < "4.0"
|
68 |
+
preshed==3.0.9 ; python_version >= "3.9" and python_version < "4.0"
|
69 |
+
pycparser==2.21 ; python_version >= "3.9" and python_version < "4.0"
|
70 |
+
pydantic-core==2.6.3 ; python_version >= "3.9" and python_version < "4.0"
|
71 |
+
pydantic==2.3.0 ; python_version >= "3.9" and python_version < "4.0"
|
72 |
+
pydub==0.25.1 ; python_version >= "3.9" and python_version < "4.0"
|
73 |
+
pyparsing==3.1.1 ; python_version >= "3.9" and python_version < "4.0"
|
74 |
+
pypdf==3.16.1 ; python_version >= "3.9" and python_version < "4.0"
|
75 |
+
python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
|
76 |
+
python-dotenv==1.0.0 ; python_version >= "3.9" and python_version < "4.0"
|
77 |
+
python-multipart==0.0.6 ; python_version >= "3.9" and python_version < "4.0"
|
78 |
+
pytz==2023.3.post1 ; python_version >= "3.9" and python_version < "4.0"
|
79 |
+
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
|
80 |
+
referencing==0.30.2 ; python_version >= "3.9" and python_version < "4.0"
|
81 |
+
regex==2023.8.8 ; python_version >= "3.9" and python_version < "4.0"
|
82 |
+
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
|
83 |
+
rpds-py==0.10.3 ; python_version >= "3.9" and python_version < "4.0"
|
84 |
+
semantic-version==2.10.0 ; python_version >= "3.9" and python_version < "4.0"
|
85 |
+
setuptools-scm==8.0.1 ; python_version >= "3.9" and python_version < "4.0"
|
86 |
+
setuptools==68.2.2 ; python_version >= "3.9" and python_version < "4.0"
|
87 |
+
six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
|
88 |
+
smart-open==6.4.0 ; python_version >= "3.9" and python_version < "4.0"
|
89 |
+
sniffio==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
|
90 |
+
soupsieve==2.5 ; python_version >= "3.9" and python_version < "4.0"
|
91 |
+
spacy-legacy==3.0.12 ; python_version >= "3.9" and python_version < "4.0"
|
92 |
+
spacy-loggers==1.0.5 ; python_version >= "3.9" and python_version < "4.0"
|
93 |
+
spacy==3.6.1 ; python_version >= "3.9" and python_version < "4.0"
|
94 |
+
sqlalchemy==2.0.21 ; python_version >= "3.9" and python_version < "4.0"
|
95 |
+
srsly==2.4.7 ; python_version >= "3.9" and python_version < "4.0"
|
96 |
+
starlette==0.27.0 ; python_version >= "3.9" and python_version < "4.0"
|
97 |
+
tenacity==8.2.3 ; python_version >= "3.9" and python_version < "4.0"
|
98 |
+
thinc==8.1.12 ; python_version >= "3.9" and python_version < "4.0"
|
99 |
+
tiktoken==0.5.1 ; python_version >= "3.9" and python_version < "4.0"
|
100 |
+
tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
|
101 |
+
toolz==0.12.0 ; python_version >= "3.9" and python_version < "4.0"
|
102 |
+
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "4.0"
|
103 |
+
typer==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
|
104 |
+
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "4.0"
|
105 |
+
typing-inspect==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
|
106 |
+
tzdata==2023.3 ; python_version >= "3.9" and python_version < "4.0"
|
107 |
+
urllib3==1.26.16 ; python_version >= "3.9" and python_version < "4.0"
|
108 |
+
uvicorn==0.23.2 ; python_version >= "3.9" and python_version < "4.0"
|
109 |
+
uvicorn[standard]==0.23.2 ; python_version >= "3.9" and python_version < "4.0"
|
110 |
+
uvloop==0.17.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.9" and python_version < "4.0"
|
111 |
+
validators==0.22.0 ; python_version >= "3.9" and python_version < "4.0"
|
112 |
+
wasabi==1.1.2 ; python_version >= "3.9" and python_version < "4.0"
|
113 |
+
watchfiles==0.20.0 ; python_version >= "3.9" and python_version < "4.0"
|
114 |
+
weaviate-client==3.24.1 ; python_version >= "3.9" and python_version < "4.0"
|
115 |
+
websockets==11.0.3 ; python_version >= "3.9" and python_version < "4.0"
|
116 |
+
yarl==1.9.2 ; python_version >= "3.9" and python_version < "4.0"
|
117 |
+
zipp==3.17.0 ; python_version >= "3.9" and python_version < "3.10"
|