Upload 3 files
Browse files- MTM_Memoir_txt.txt +0 -0
- app.py +153 -0
- requirements.txt +128 -0
MTM_Memoir_txt.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# We are going to develop the code for the RAG here. This is going to be the first and the only attempt IA!!
|
2 |
+
# To create the POC
|
3 |
+
|
4 |
+
# ! We need to do the following,
|
5 |
+
# Convert the PDF to Embeddings and save it into a vector database.
|
6 |
+
# Load LLAMA 2
|
7 |
+
# Connect LLAMA 2 to the vector database.
|
8 |
+
# Ask Questions and give answers.
|
9 |
+
|
10 |
+
|
11 |
+
# ! LLAMA IS LOADED
|
12 |
+
import gradio as gr
|
13 |
+
|
14 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
15 |
+
from transformers import BitsAndBytesConfig
|
16 |
+
import torch
|
17 |
+
import json
|
18 |
+
|
19 |
+
from torch import cuda
|
20 |
+
import torch
|
21 |
+
import transformers
|
22 |
+
from time import time
|
23 |
+
import chromadb
|
24 |
+
from chromadb.config import Settings
|
25 |
+
from langchain.llms import huggingface_pipeline
|
26 |
+
from langchain.document_loaders import TextLoader
|
27 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
28 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
29 |
+
from langchain.chains import RetrievalQA
|
30 |
+
from langchain.vectorstores.chroma import Chroma
|
31 |
+
|
32 |
+
|
33 |
+
nf4_config = BitsAndBytesConfig(
|
34 |
+
load_in_4bit=True,
|
35 |
+
bnb_4bit_quant_type="nf4",
|
36 |
+
bnb_4bit_compute_dtype=torch.bfloat16
|
37 |
+
)
|
38 |
+
|
39 |
+
# Change the model path here to test any other model.
|
40 |
+
# model_path = 'training_date_02_10_2023_psql/final_merged_checkpoint'
|
41 |
+
model_path = 'Llama-13b-chat'
|
42 |
+
|
43 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
44 |
+
model_path,
|
45 |
+
local_files_only=True,
|
46 |
+
)
|
47 |
+
|
48 |
+
model = AutoModelForCausalLM.from_pretrained(
|
49 |
+
model_path,
|
50 |
+
local_files_only=True,
|
51 |
+
low_cpu_mem_usage=True,
|
52 |
+
device_map="auto",
|
53 |
+
offload_folder="offload/",
|
54 |
+
cache_dir="cache/",
|
55 |
+
quantization_config=nf4_config # forgot this on the first try so full model was loaded.
|
56 |
+
)
|
57 |
+
|
58 |
+
model_config = transformers.AutoConfig.from_pretrained(model_path)
|
59 |
+
|
60 |
+
|
61 |
+
# define query huggingface_pipeline
|
62 |
+
query_pipeline = transformers.pipeline(
|
63 |
+
"text-generation",
|
64 |
+
model=model,
|
65 |
+
tokenizer=tokenizer,
|
66 |
+
torch_dtype=torch.float16,
|
67 |
+
device_map="auto"
|
68 |
+
)
|
69 |
+
|
70 |
+
llm = huggingface_pipeline.HuggingFacePipeline(pipeline=query_pipeline)
|
71 |
+
|
72 |
+
|
73 |
+
# Ingestion of data using text loader
|
74 |
+
loader = TextLoader("MTM_Memoir_txt.txt", encoding="utf-8")
|
75 |
+
documents = loader.load()
|
76 |
+
|
77 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
|
78 |
+
all_splits = text_splitter.split_documents(documents)
|
79 |
+
|
80 |
+
# let's create the embeddings and store in vector store
|
81 |
+
model_name = "sentence-transformers/all-mpnet-base-v2"
|
82 |
+
model_kwargs = {"device": "cuda"}
|
83 |
+
|
84 |
+
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
|
85 |
+
|
86 |
+
# initialize chromadb
|
87 |
+
vectordb = Chroma.from_documents(documents=all_splits, embedding=embeddings, persist_directory="chroma_db")
|
88 |
+
|
89 |
+
# Initialize the chain
|
90 |
+
retriever = vectordb.as_retriever()
|
91 |
+
qa = RetrievalQA.from_chain_type(
|
92 |
+
llm=llm,
|
93 |
+
chain_type="stuff",
|
94 |
+
retriever=retriever,
|
95 |
+
verbose=True
|
96 |
+
)
|
97 |
+
|
98 |
+
# let's test the RAG
|
99 |
+
def test_rag(qa, query):
|
100 |
+
print(query)
|
101 |
+
result = qa.run(query)
|
102 |
+
print(f"Result \t {result}")
|
103 |
+
|
104 |
+
test_rag(qa, "Hello when were you born?")
|
105 |
+
|
106 |
+
|
107 |
+
def preprocess_query(query):
|
108 |
+
|
109 |
+
# load the query as a dict
|
110 |
+
# res = json.loads(str(query))
|
111 |
+
|
112 |
+
# human_language = res['human_language']
|
113 |
+
|
114 |
+
# SQL_TABLE_CONTEXT = "CREATE TABLE properties (address character, details characterstate character, property_type character, price integer, bedrooms integer, bathrooms integer, sqft integer)"
|
115 |
+
|
116 |
+
# INTRO = f"<s>[INST] <<SYS>> \
|
117 |
+
# You are a helpful, genius data scientist, who has access to a database that contains listing of properties in New York. Your job is to write PostgreSQL Query to fetch data based on User Request and Parameters. If you can't reply correctly just say that not enough information was provided \n\n<</SYS>>"
|
118 |
+
INTRO = f"<s>[INST] <<SYS>>You are former Malaysian Prime Minister Tun Dr Mahathir Mohamad.. A visionary leader \n\n<</SYS>>"
|
119 |
+
INSTRUCTION = f"### Instruction\n Respond to the following query by your subject {query} Just like yourself. \n\n"
|
120 |
+
|
121 |
+
RESPONSE = f"### Response:\n\n"
|
122 |
+
|
123 |
+
final_payload = INTRO + INSTRUCTION + RESPONSE
|
124 |
+
payload_length = len(final_payload)
|
125 |
+
|
126 |
+
return final_payload
|
127 |
+
|
128 |
+
|
129 |
+
def get_result(qa=qa, query = ""):
|
130 |
+
return qa.run(query)
|
131 |
+
|
132 |
+
|
133 |
+
def predict(query):
|
134 |
+
processed_query = preprocess_query(query=query)
|
135 |
+
result = get_result(query=processed_query)
|
136 |
+
return(result)
|
137 |
+
|
138 |
+
|
139 |
+
# ! The following will also work now! I mistakenly wrote ap_name insted of api_name in the submit_button.click()
|
140 |
+
|
141 |
+
with gr.Blocks() as sql_generator:
|
142 |
+
query = gr.Textbox(label="Query", placeholder='Ask the president?')
|
143 |
+
|
144 |
+
output = gr.Textbox(label="Output")
|
145 |
+
submit_button = gr.Button("Submit")
|
146 |
+
submit_button.click(fn=predict,
|
147 |
+
inputs=query,
|
148 |
+
outputs=output, api_name="predict"
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
sql_generator.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.23.0
|
2 |
+
aiofiles==23.2.1
|
3 |
+
aiohttp==3.8.6
|
4 |
+
aiosignal==1.3.1
|
5 |
+
altair==5.1.2
|
6 |
+
annotated-types==0.6.0
|
7 |
+
anyio==3.7.1
|
8 |
+
async-timeout==4.0.3
|
9 |
+
attrs==23.1.0
|
10 |
+
backoff==2.2.1
|
11 |
+
bcrypt==4.0.1
|
12 |
+
bitsandbytes==0.41.1
|
13 |
+
certifi==2023.7.22
|
14 |
+
charset-normalizer==3.3.0
|
15 |
+
chroma-hnswlib==0.7.3
|
16 |
+
chromadb==0.4.14
|
17 |
+
click==8.1.7
|
18 |
+
coloredlogs==15.0.1
|
19 |
+
contourpy==1.1.1
|
20 |
+
cycler==0.12.1
|
21 |
+
dataclasses-json==0.6.1
|
22 |
+
einops==0.7.0
|
23 |
+
exceptiongroup==1.1.3
|
24 |
+
fastapi==0.103.2
|
25 |
+
ffmpy==0.3.1
|
26 |
+
filelock==3.12.4
|
27 |
+
flatbuffers==23.5.26
|
28 |
+
fonttools==4.43.1
|
29 |
+
frozenlist==1.4.0
|
30 |
+
fsspec==2023.9.2
|
31 |
+
gradio==3.47.1
|
32 |
+
gradio_client==0.6.0
|
33 |
+
greenlet==3.0.0
|
34 |
+
grpcio==1.59.0
|
35 |
+
h11==0.14.0
|
36 |
+
httpcore==0.18.0
|
37 |
+
httptools==0.6.0
|
38 |
+
httpx==0.25.0
|
39 |
+
huggingface-hub==0.17.3
|
40 |
+
humanfriendly==10.0
|
41 |
+
idna==3.4
|
42 |
+
importlib-resources==6.1.0
|
43 |
+
Jinja2==3.1.2
|
44 |
+
joblib==1.3.2
|
45 |
+
jsonpatch==1.33
|
46 |
+
jsonpointer==2.4
|
47 |
+
jsonschema==4.19.1
|
48 |
+
jsonschema-specifications==2023.7.1
|
49 |
+
kiwisolver==1.4.5
|
50 |
+
langchain==0.0.315
|
51 |
+
langsmith==0.0.43
|
52 |
+
MarkupSafe==2.1.3
|
53 |
+
marshmallow==3.20.1
|
54 |
+
matplotlib==3.8.0
|
55 |
+
monotonic==1.6
|
56 |
+
mpmath==1.3.0
|
57 |
+
multidict==6.0.4
|
58 |
+
mypy-extensions==1.0.0
|
59 |
+
networkx==3.1
|
60 |
+
nltk==3.8.1
|
61 |
+
numpy==1.26.1
|
62 |
+
nvidia-cublas-cu12==12.1.3.1
|
63 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
64 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
65 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
66 |
+
nvidia-cudnn-cu12==8.9.2.26
|
67 |
+
nvidia-cufft-cu12==11.0.2.54
|
68 |
+
nvidia-curand-cu12==10.3.2.106
|
69 |
+
nvidia-cusolver-cu12==11.4.5.107
|
70 |
+
nvidia-cusparse-cu12==12.1.0.106
|
71 |
+
nvidia-nccl-cu12==2.18.1
|
72 |
+
nvidia-nvjitlink-cu12==12.2.140
|
73 |
+
nvidia-nvtx-cu12==12.1.105
|
74 |
+
onnxruntime==1.16.1
|
75 |
+
orjson==3.9.9
|
76 |
+
overrides==7.4.0
|
77 |
+
packaging==23.2
|
78 |
+
pandas==2.1.1
|
79 |
+
Pillow==10.1.0
|
80 |
+
posthog==3.0.2
|
81 |
+
protobuf==4.24.4
|
82 |
+
psutil==5.9.6
|
83 |
+
pulsar-client==3.3.0
|
84 |
+
pydantic==2.4.2
|
85 |
+
pydantic_core==2.10.1
|
86 |
+
pydub==0.25.1
|
87 |
+
pyparsing==3.1.1
|
88 |
+
PyPika==0.48.9
|
89 |
+
python-dateutil==2.8.2
|
90 |
+
python-dotenv==1.0.0
|
91 |
+
python-multipart==0.0.6
|
92 |
+
pytz==2023.3.post1
|
93 |
+
PyYAML==6.0.1
|
94 |
+
referencing==0.30.2
|
95 |
+
regex==2023.10.3
|
96 |
+
requests==2.31.0
|
97 |
+
rpds-py==0.10.6
|
98 |
+
safetensors==0.4.0
|
99 |
+
scikit-learn==1.3.1
|
100 |
+
scipy==1.11.3
|
101 |
+
semantic-version==2.10.0
|
102 |
+
sentence-transformers==2.2.2
|
103 |
+
sentencepiece==0.1.99
|
104 |
+
six==1.16.0
|
105 |
+
sniffio==1.3.0
|
106 |
+
SQLAlchemy==2.0.22
|
107 |
+
starlette==0.27.0
|
108 |
+
sympy==1.12
|
109 |
+
tenacity==8.2.3
|
110 |
+
threadpoolctl==3.2.0
|
111 |
+
tokenizers==0.14.1
|
112 |
+
toolz==0.12.0
|
113 |
+
torch==2.1.0
|
114 |
+
torchvision==0.16.0
|
115 |
+
tqdm==4.66.1
|
116 |
+
transformers==4.34.0
|
117 |
+
triton==2.1.0
|
118 |
+
typer==0.9.0
|
119 |
+
typing-inspect==0.9.0
|
120 |
+
typing_extensions==4.8.0
|
121 |
+
tzdata==2023.3
|
122 |
+
urllib3==2.0.6
|
123 |
+
uvicorn==0.23.2
|
124 |
+
uvloop==0.18.0
|
125 |
+
watchfiles==0.21.0
|
126 |
+
websockets==11.0.3
|
127 |
+
xformers==0.0.22.post4
|
128 |
+
yarl==1.9.2
|