AmeerH commited on
Commit
f1cccc9
·
1 Parent(s): 23f91d6

Upload 3 files

Browse files
Files changed (3) hide show
  1. MTM_Memoir_txt.txt +0 -0
  2. app.py +153 -0
  3. requirements.txt +128 -0
MTM_Memoir_txt.txt ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # We are going to develop the code for the RAG here. This is going to be the first and the only attempt IA!!
2
+ # To create the POC
3
+
4
+ # ! We need to do the following,
5
+ # Convert the PDF to Embeddings and save it into a vector database.
6
+ # Load LLAMA 2
7
+ # Connect LLAMA 2 to the vector database.
8
+ # Ask Questions and give answers.
9
+
10
+
11
+ # ! LLAMA IS LOADED
12
+ import gradio as gr
13
+
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM
15
+ from transformers import BitsAndBytesConfig
16
+ import torch
17
+ import json
18
+
19
+ from torch import cuda
20
+ import torch
21
+ import transformers
22
+ from time import time
23
+ import chromadb
24
+ from chromadb.config import Settings
25
+ from langchain.llms import huggingface_pipeline
26
+ from langchain.document_loaders import TextLoader
27
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
28
+ from langchain.embeddings import HuggingFaceEmbeddings
29
+ from langchain.chains import RetrievalQA
30
+ from langchain.vectorstores.chroma import Chroma
31
+
32
+
33
# Quantization settings handed to `from_pretrained` below: 4-bit weights in
# the "nf4" format with bfloat16 as the compute dtype.
nf4_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Change the model path here to test any other model.
# model_path = 'training_date_02_10_2023_psql/final_merged_checkpoint'
model_path = "Llama-13b-chat"

# Tokenizer and weights are both resolved with `local_files_only=True`.
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    # Without the quantization config the full-precision model gets loaded.
    quantization_config=nf4_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    local_files_only=True,
    offload_folder="offload/",
    cache_dir="cache/",
)

model_config = transformers.AutoConfig.from_pretrained(model_path)
59
+
60
+
61
+ # define query huggingface_pipeline
62
# Text-generation pipeline built around the model and tokenizer loaded above.
# NOTE(review): `model` is already instantiated here, so `torch_dtype` and
# `device_map` are presumably not applied to it — confirm against the
# transformers pipeline documentation before relying on or removing them.
query_pipeline = transformers.pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    device_map="auto",
    torch_dtype=torch.float16,
)

# LangChain wrapper so the pipeline can be used as the chain's LLM.
llm = huggingface_pipeline.HuggingFacePipeline(pipeline=query_pipeline)
71
+
72
+
73
+ # Ingestion of data using text loader
74
# Load the memoir text file as LangChain documents.
loader = TextLoader("MTM_Memoir_txt.txt", encoding="utf-8")
documents = loader.load()

# Split the documents into chunks of up to 1000 characters with a
# 20-character overlap between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
all_splits = text_splitter.split_documents(documents)

# Embedding model used to vectorize the chunks. The device falls back to the
# CPU when CUDA is not available instead of failing on a hard-coded "cuda".
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cuda" if cuda.is_available() else "cpu"}

embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Embed every chunk and store it in a Chroma collection under "chroma_db".
# NOTE(review): this runs on every start against the same persist directory,
# so the chunks are presumably re-added each time — confirm and, if so, load
# the existing store instead of re-ingesting.
vectordb = Chroma.from_documents(
    documents=all_splits,
    embedding=embeddings,
    persist_directory="chroma_db",
)
88
+
89
+ # Initialize the chain
90
# Expose the vector store as a retriever and wire it, together with the LLM,
# into a RetrievalQA chain of type "stuff".
retriever = vectordb.as_retriever()
qa = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    verbose=True,
)
97
+
98
+ # let's test the RAG
99
def test_rag(qa, query):
    """Print *query*, run it through *qa*, and print the answer.

    Args:
        qa: Object exposing ``run(query)``.
        query: Question text handed to ``qa.run``.
    """
    print(query)
    answer = qa.run(query)
    print(f"Result \t {answer}")


# Smoke test: runs once when the module is executed.
test_rag(qa, "Hello when were you born?")
105
+
106
+
107
def preprocess_query(query):
    """Wrap a user question in the persona prompt sent to the QA chain.

    Args:
        query: Question text; it is interpolated into the instruction
            section verbatim.

    Returns:
        The full prompt string: the system block, the instruction section
        containing *query*, and an empty response header.
    """
    # NOTE(review): the prompt opens "[INST]" but never emits a closing
    # "[/INST]" — confirm whether that is intended for this model.
    intro = "<s>[INST] <<SYS>>You are former Malaysian Prime Minister Tun Dr Mahathir Mohamad.. A visionary leader \n\n<</SYS>>"
    instruction = f"### Instruction\n Respond to the following query by your subject {query} Just like yourself. \n\n"
    response = "### Response:\n\n"

    return intro + instruction + response
127
+
128
+
129
def get_result(qa=qa, query=""):
    """Run *query* through the QA chain and return its answer.

    Args:
        qa: Chain to query; defaults to the module-level ``qa`` that exists
            when this function is defined.
        query: Prompt text handed to ``qa.run``.

    Returns:
        Whatever ``qa.run(query)`` returns.
    """
    answer = qa.run(query)
    return answer
131
+
132
+
133
def predict(query):
    """Build the persona prompt for *query* and return the chain's answer.

    Args:
        query: Raw question text.

    Returns:
        The result of ``get_result`` for the prompt produced by
        ``preprocess_query``.
    """
    prompt = preprocess_query(query=query)
    return get_result(query=prompt)
137
+
138
+
139
+ # ! The following will also work now! I mistakenly wrote ap_name instead of api_name in the submit_button.click()
140
+
141
# Gradio UI: one input box, one output box, and a button that sends the
# input through `predict`; the click handler is registered under the API
# name "predict".
with gr.Blocks() as sql_generator:
    query = gr.Textbox(placeholder='Ask the president?', label="Query")
    output = gr.Textbox(label="Output")

    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=predict,
        inputs=query,
        outputs=output,
        api_name="predict",
    )

# Start the Gradio app.
sql_generator.launch()
requirements.txt ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.23.0
2
+ aiofiles==23.2.1
3
+ aiohttp==3.8.6
4
+ aiosignal==1.3.1
5
+ altair==5.1.2
6
+ annotated-types==0.6.0
7
+ anyio==3.7.1
8
+ async-timeout==4.0.3
9
+ attrs==23.1.0
10
+ backoff==2.2.1
11
+ bcrypt==4.0.1
12
+ bitsandbytes==0.41.1
13
+ certifi==2023.7.22
14
+ charset-normalizer==3.3.0
15
+ chroma-hnswlib==0.7.3
16
+ chromadb==0.4.14
17
+ click==8.1.7
18
+ coloredlogs==15.0.1
19
+ contourpy==1.1.1
20
+ cycler==0.12.1
21
+ dataclasses-json==0.6.1
22
+ einops==0.7.0
23
+ exceptiongroup==1.1.3
24
+ fastapi==0.103.2
25
+ ffmpy==0.3.1
26
+ filelock==3.12.4
27
+ flatbuffers==23.5.26
28
+ fonttools==4.43.1
29
+ frozenlist==1.4.0
30
+ fsspec==2023.9.2
31
+ gradio==3.47.1
32
+ gradio_client==0.6.0
33
+ greenlet==3.0.0
34
+ grpcio==1.59.0
35
+ h11==0.14.0
36
+ httpcore==0.18.0
37
+ httptools==0.6.0
38
+ httpx==0.25.0
39
+ huggingface-hub==0.17.3
40
+ humanfriendly==10.0
41
+ idna==3.4
42
+ importlib-resources==6.1.0
43
+ Jinja2==3.1.2
44
+ joblib==1.3.2
45
+ jsonpatch==1.33
46
+ jsonpointer==2.4
47
+ jsonschema==4.19.1
48
+ jsonschema-specifications==2023.7.1
49
+ kiwisolver==1.4.5
50
+ langchain==0.0.315
51
+ langsmith==0.0.43
52
+ MarkupSafe==2.1.3
53
+ marshmallow==3.20.1
54
+ matplotlib==3.8.0
55
+ monotonic==1.6
56
+ mpmath==1.3.0
57
+ multidict==6.0.4
58
+ mypy-extensions==1.0.0
59
+ networkx==3.1
60
+ nltk==3.8.1
61
+ numpy==1.26.1
62
+ nvidia-cublas-cu12==12.1.3.1
63
+ nvidia-cuda-cupti-cu12==12.1.105
64
+ nvidia-cuda-nvrtc-cu12==12.1.105
65
+ nvidia-cuda-runtime-cu12==12.1.105
66
+ nvidia-cudnn-cu12==8.9.2.26
67
+ nvidia-cufft-cu12==11.0.2.54
68
+ nvidia-curand-cu12==10.3.2.106
69
+ nvidia-cusolver-cu12==11.4.5.107
70
+ nvidia-cusparse-cu12==12.1.0.106
71
+ nvidia-nccl-cu12==2.18.1
72
+ nvidia-nvjitlink-cu12==12.2.140
73
+ nvidia-nvtx-cu12==12.1.105
74
+ onnxruntime==1.16.1
75
+ orjson==3.9.9
76
+ overrides==7.4.0
77
+ packaging==23.2
78
+ pandas==2.1.1
79
+ Pillow==10.1.0
80
+ posthog==3.0.2
81
+ protobuf==4.24.4
82
+ psutil==5.9.6
83
+ pulsar-client==3.3.0
84
+ pydantic==2.4.2
85
+ pydantic_core==2.10.1
86
+ pydub==0.25.1
87
+ pyparsing==3.1.1
88
+ PyPika==0.48.9
89
+ python-dateutil==2.8.2
90
+ python-dotenv==1.0.0
91
+ python-multipart==0.0.6
92
+ pytz==2023.3.post1
93
+ PyYAML==6.0.1
94
+ referencing==0.30.2
95
+ regex==2023.10.3
96
+ requests==2.31.0
97
+ rpds-py==0.10.6
98
+ safetensors==0.4.0
99
+ scikit-learn==1.3.1
100
+ scipy==1.11.3
101
+ semantic-version==2.10.0
102
+ sentence-transformers==2.2.2
103
+ sentencepiece==0.1.99
104
+ six==1.16.0
105
+ sniffio==1.3.0
106
+ SQLAlchemy==2.0.22
107
+ starlette==0.27.0
108
+ sympy==1.12
109
+ tenacity==8.2.3
110
+ threadpoolctl==3.2.0
111
+ tokenizers==0.14.1
112
+ toolz==0.12.0
113
+ torch==2.1.0
114
+ torchvision==0.16.0
115
+ tqdm==4.66.1
116
+ transformers==4.34.0
117
+ triton==2.1.0
118
+ typer==0.9.0
119
+ typing-inspect==0.9.0
120
+ typing_extensions==4.8.0
121
+ tzdata==2023.3
122
+ urllib3==2.0.6
123
+ uvicorn==0.23.2
124
+ uvloop==0.18.0
125
+ watchfiles==0.21.0
126
+ websockets==11.0.3
127
+ xformers==0.0.22.post4
128
+ yarl==1.9.2