Shawn732 committed
Commit df8bb52 · 1 Parent(s): 5bc5847

1st Init Commit

Files changed (7)
  1. Dockerfile +59 -0
  2. doc_reader.py +53 -0
  3. main.py +232 -0
  4. model.py +83 -0
  5. requirements.txt +15 -0
  6. start_server.sh +22 -0
  7. streamlit_app.py +23 -0
Dockerfile ADDED
@@ -0,0 +1,59 @@
+ # Use an NVIDIA CUDA base image
+ ARG CUDA_IMAGE="12.1.1-devel-ubuntu22.04"
+ FROM nvidia/cuda:${CUDA_IMAGE}
+
+ ENV HOST=0.0.0.0
+
+ # Set the working directory in the container to /app
+ #WORKDIR /app
+
+ RUN mkdir -p /app/cache && chmod -R 777 /app/cache
+
+ ENV HF_HOME=/app/cache
+
+ # Install Python and pip
+ RUN apt-get update && apt-get upgrade -y \
+     && apt-get install -y git build-essential \
+     python3 python3-pip gcc wget \
+     ocl-icd-opencl-dev opencl-headers clinfo \
+     libclblast-dev libopenblas-dev \
+     && mkdir -p /etc/OpenCL/vendors && echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd
+
+ ENV CUDA_DOCKER_ARCH=all
+ ENV LLAMA_CUBLAS=1
+
+ # Copy the current directory contents into the container at /app
+ COPY . /app
+
+ # Install required packages from requirements.txt
+ COPY ./requirements.txt /app/requirements.txt
+ RUN pip3 install --no-cache-dir -r /app/requirements.txt
+
+ # Expose the ports for FastAPI and Streamlit
+ EXPOSE 8000
+ EXPOSE 8501
+
+ RUN useradd -m -u 1000 user
+ # Switch to the "user" user
+ USER user
+ WORKDIR /home/user/app
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     PYTHONPATH=$HOME/app \
+     PYTHONUNBUFFERED=1 \
+     GRADIO_ALLOW_FLAGGING=never \
+     GRADIO_NUM_PORTS=1 \
+     GRADIO_SERVER_NAME=0.0.0.0 \
+     GRADIO_THEME=huggingface \
+     SYSTEM=spaces
+
+ # Copy the current directory contents into the container at $HOME/app, setting the owner to the user
+ COPY --chown=user . /home/user/app
+
+ # Copy and give execute permissions to the start script
+ COPY start_server.sh /app/start_server.sh
+ RUN chmod +x /app/start_server.sh
+
+ # Run the start script
+ CMD ["/app/start_server.sh"]
doc_reader.py ADDED
@@ -0,0 +1,53 @@
+ import glob
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import Qdrant
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_core.documents.base import Document
+
+ class DocReader:
+     def __init__(self, pdf_path, model_path="sentence-transformers/all-mpnet-base-v2", persist_directory="db"):
+         self.pdfs = glob.glob(f"{pdf_path}/*.pdf")  # Collect all PDF files in the folder
+         self.model_path = model_path
+         self.persist_directory = persist_directory
+
+     def load_pdfs(self):
+         all_pages = []
+         for pdf_file in self.pdfs:
+             loader = PyPDFLoader(pdf_file)
+             pages = loader.load()
+             all_pages.extend(pages)
+         return all_pages
+
+     def convert_to_markdown(self, documents):
+         markdown_text = ""
+         for doc in documents:
+             page_text = doc.page_content.replace('\n', '\n\n')  # Add extra newline for Markdown
+             markdown_text += page_text + "\n\n---\n\n"
+         return markdown_text
+
+     def split_text(self, pages):
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=128,
+             chunk_overlap=24)
+         documents = [Document(page_content=page) for page in pages]
+         split_documents = text_splitter.split_documents(documents)
+         texts = [doc.page_content for doc in split_documents]
+
+         return texts
+
+     def generate_embeddings(self, texts):
+         embeddings = HuggingFaceEmbeddings(
+             model_name=self.model_path,
+             model_kwargs={"device": "cuda:0"},
+             encode_kwargs={"normalize_embeddings": True},
+         )
+         documents = [Document(page_content=text) for text in texts]
+
+         # Keep the vector store on the instance so search_similar() can use it
+         self.db = Qdrant.from_documents(documents, embeddings, location=":memory:", collection_name="pdf_collection")
+         return self.db
+
+     def search_similar(self, input_text, k=3):
+         results = self.db.similarity_search(input_text, k=k)
+         return results
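A minimal usage sketch for DocReader, mirroring the flow that load_pdfs() in main.py follows; the "./docs" folder and the query string are illustrative assumptions, and the embedding step expects a CUDA device since the class pins HuggingFaceEmbeddings to cuda:0:

# Hypothetical standalone use of DocReader; "./docs" and the query are illustrative.
from doc_reader import DocReader

reader = DocReader("./docs")                        # collects ./docs/*.pdf
pages = reader.load_pdfs()                          # one Document per PDF page
markdown_text = reader.convert_to_markdown(pages)   # join pages into a single Markdown string
chunks = reader.split_text([markdown_text])         # 128-character chunks with 24-character overlap
db = reader.generate_embeddings(chunks)             # builds an in-memory Qdrant collection
for doc in reader.search_similar("how do I restart sbox-admin?", k=3):
    print(doc.page_content)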
main.py ADDED
@@ -0,0 +1,232 @@
+ # main.py
+ import logging
+ from fastapi import FastAPI, Request
+ from fastapi.middleware.cors import CORSMiddleware
+ import nest_asyncio
+ from pyngrok import ngrok
+ import uvicorn
+ import json
+ from model import Model
+ from doc_reader import DocReader
+ from transformers import GenerationConfig, pipeline
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA
+ from langchain.schema.runnable import RunnableBranch
+ from langchain_core.runnables import RunnableLambda
+
+
+ # Logger configuration
+ logging.basicConfig(level=logging.INFO,
+                     format='%(asctime)s [%(levelname)s] %(message)s',
+                     datefmt='%Y-%m-%d %H:%M:%S')
+ logger = logging.getLogger(__name__)
+
+ # Add path to sys
+ # sys.path.insert(0,'/opt/accelerate')
+ # sys.path.insert(0,'/opt/uvicorn')
+ # sys.path.insert(0,'/opt/pyngrok')
+ # sys.path.insert(0,'/opt/huggingface_hub')
+ # sys.path.insert(0,'/opt/nest_asyncio')
+ # sys.path.insert(0,'/opt/transformers')
+ # sys.path.insert(0,'/opt/pytorch')
+
+ # Initialize FastAPI app
+ app = FastAPI()
+ NGROK_TOKEN = "2aQUM6MDkhjcPEBbIFTiu4cZBBr_sMMei8h5yejFbxFeMFuQ"  # Replace with your NGROK token
+ #MODEL_NAME = "/opt/Llama-2-13B-chat-GPTQ"
+ #MODEL_NAME = "MediaTek-Research/Breeze-7B-Instruct-64k-v0.1"
+ MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"
+ PDF_PATH = "/opt/docs"
+ CLASSIFIER_MODEL_NAME = "roberta-large-mnli"
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=['*'],
+     allow_credentials=True,
+     allow_methods=['*'],
+     allow_headers=['*'],
+ )
+
+ model_instance = Model(MODEL_NAME)
+ model_instance.load()
+ #model_instance.load(model_name_or_path=GGUF_HUGGINGFACE_REPO, model_basename=GGUF_HUGGINGFACE_BIN_FILE)
+
+ # classifier_model = pipeline("zero-shot-classification",
+ #                             model=CLASSIFIER_MODEL_NAME)
+
+
+ @app.post("/predict")
+ async def predict_text(request: Request):
+     try:
+         # Parse request body as JSON
+         request_body = await request.json()
+
+         prompt = request_body.get("prompt", "")
+         # TODO: handle additional parameters like 'temperature' or 'max_tokens' if needed
+         result = general_chain.invoke({"question": prompt})
+         logger.info(f"Result: {result}")
+         formatted_response = {
+             "choices": [
+                 {
+                     "message": {
+                         "content": result['result']
+                     }
+                 }
+             ]
+         }
+         return formatted_response
+     except json.JSONDecodeError:
+         return {"error": "Invalid JSON format"}
+
+ def load_pdfs():
+     global db
+     doc_reader = DocReader(PDF_PATH)
+     # Load PDFs and convert to Markdown
+     pages = doc_reader.load_pdfs()
+     markdown_text = doc_reader.convert_to_markdown(pages)
+     texts = doc_reader.split_text([markdown_text])  # split_text takes a list of Markdown texts
+     # Generate embeddings
+     db = doc_reader.generate_embeddings(texts)
+
+ # def classify_sequence(input_data):
+ #     sequence_to_classify = input_data["question"]
+ #     candidate_labels = ['LinuxCommand', 'TechnicalSupport', 'GeneralResponse']
+ #     classification = classifier_model(sequence_to_classify, candidate_labels)
+ #     # Extract the label with the highest score
+ #     return {"topic": classification['labels'][0], "question": sequence_to_classify}
+
+ def format_output(output):
+     return {"result": output}
+
+ def setup_chain():
+     #global full_chain
+     #global classifier_chain
+     global command_chain
+     #global support_chain
+     global general_chain
+     generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
+     generation_config.max_new_tokens = 1024
+     generation_config.temperature = 0.3
+     generation_config.top_p = 0.9
+     generation_config.do_sample = True
+     generation_config.repetition_penalty = 1.15
+
+     text_pipeline = pipeline(
+         "text-generation",
+         model=model_instance.model,
+         tokenizer=model_instance.tokenizer,
+         return_full_text=True,
+         generation_config=generation_config,
+     )
+
+     llm = HuggingFacePipeline(pipeline=text_pipeline)
+
+     # Classifier
+     #classifier_runnable = RunnableLambda(classify_sequence)
+     # Formatter
+     output_runnable = RunnableLambda(format_output)
+
+     # System Commands
+     command_template = """
+ [INST] <<SYS>>
+ As a Gemini Central engineer specializing in Linux, evaluate the user's input and choose the most likely command they want to execute from these options:
+ - 'systemctl stop sbox-admin'
+ - 'systemctl start sbox-admin'
+ - 'systemctl restart sbox-admin'
+ Respond with the chosen command. If uncertain, reply with 'No command will be executed'.
+ <</SYS>>
+ question:
+ {question}
+ answer:
+ [/INST]"""
+     command_chain = (PromptTemplate(template=command_template, input_variables=["question"]) | llm | output_runnable)
+
+     # Support
+     # support_template = """
+     # [INST] <<SYS>>
+     # Act as a Gemini support engineer who is good at reading technical data. Use the following information to answer the question at the end.
+     # <</SYS>>
+     # {context}
+     # {question}
+     # answer:
+     # [/INST]
+     # """
+
+
+     # General
+     general_template = """
+ [INST] <<SYS>>
+ You are an advanced AI assistant designed to provide assistance with a wide range of queries.
+ Users may request you to assume various roles or perform diverse tasks.
+ <</SYS>>
+ question:
+ {question}
+ answer:
+ [/INST]"""
+     general_chain = (PromptTemplate(template=general_template, input_variables=["question"]) | llm | output_runnable)
+
+     #support_prompt = PromptTemplate(template=support_template, input_variables=["context","question"])
+
+     #support_chain = RetrievalQA.from_llm(llm=llm, retriever=db.as_retriever(), prompt=support_prompt, input_key="question", return_source_documents=True, verbose=True)
+
+     # support_chain = RetrievalQA.from_chain_type(
+     #     llm=llm,
+     #     chain_type="stuff",
+     #     #retriever=db.as_retriever(search_kwargs={"k": 3}),
+     #     retriever=db.as_retriever(),
+     #     input_key="question",
+     #     return_source_documents=True,
+     #     chain_type_kwargs={"prompt": support_prompt},
+     #     verbose=False
+     # )
+     # logger.info("support chain loaded successfully.")
+
+     # branch = RunnableBranch(
+     #     (lambda x: x == "command", command_chain),
+     #     (lambda x: x == "support", support_chain),
+     #     general_chain,  # Default chain
+     # )
+
+     # def route_classification(output):
+     #     if output['topic'] == 'LinuxCommand':
+     #         logger.info("Routing to command chain")
+     #         return command_chain
+     #     elif output['topic'] == 'TechnicalSupport':
+     #         logger.info("Routing to support chain")
+     #         return support_chain
+     #     else:
+     #         logger.info("Routing to general chain")
+     #         return general_chain
+
+     # routing_runnable = RunnableLambda(route_classification)
+
+     # Full chain integration
+     #full_chain = classifier_runnable | routing_runnable
+
+     #logger.info("Full chain loaded successfully.")
+     return general_chain
+
+
+ ###############
+ # launch once at startup
+ #load_pdfs()
+ setup_chain()
+ ###############
+
+ if __name__ == "__main__":
+
+     if NGROK_TOKEN is not None:
+         ngrok.set_auth_token(NGROK_TOKEN)
+
+     ngrok_tunnel = ngrok.connect(8000)
+     public_url = ngrok_tunnel.public_url
+
+     print('Public URL:', public_url)
+     print("You can use {}/predict to get the assistant result.".format(public_url))
+     logger.info("You can use {}/predict to get the assistant result.".format(public_url))
+
+     nest_asyncio.apply()
+     uvicorn.run(app, host="0.0.0.0", port=8000)
+
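For reference, a small client sketch for the /predict route defined above; it assumes the API is reachable at http://localhost:8000 (as in start_server.sh), and the prompt text is only an example:

# Hypothetical client for the /predict endpoint; the URL and prompt are illustrative.
import requests

resp = requests.post(
    "http://localhost:8000/predict",
    json={"prompt": "How do I restart the sbox-admin service?"},
)
resp.raise_for_status()
# predict_text() wraps the chain output in an OpenAI-style "choices" structure.
print(resp.json()["choices"][0]["message"]["content"])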
model.py ADDED
@@ -0,0 +1,83 @@
+ # model.py
+ import logging
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+
+ # Logger configuration
+ logging.basicConfig(level=logging.INFO,
+                     format='%(asctime)s [%(levelname)s] %(message)s',
+                     datefmt='%Y-%m-%d %H:%M:%S')
+ logger = logging.getLogger(__name__)
+
+ #model_path = "/opt/Llama-2-13B-chat-GPTQ"
+
+ class Model:
+     def __init__(self, model_path):
+         self.model_name = model_path
+         self.model = None
+         self.tokenizer = None
+         self.loaded = False
+
+     def load(self, precision='fp16'):
+         try:
+             # Check if CUDA is available
+             if not torch.cuda.is_available():
+                 raise EnvironmentError("CUDA not available.")
+             # Set precision settings
+             if precision == 'fp16':
+                 torch_dtype = torch.float16
+             else:
+                 torch_dtype = torch.float32
+
+             # Initialize tokenizer
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+             # Set up model configuration
+             config = AutoConfig.from_pretrained(self.model_name)
+
+             #config.quantization_config["disable_exllama"] = False
+             #config.quantization_config["use_exllama"] = True
+             #config.quantization_config["exllama_config"] = {"version": 2}
+
+             # Load model with configuration and precision
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 self.model_name,
+                 config=config,
+                 device_map="cuda:0",  # Set to GPU 0
+                 torch_dtype=torch_dtype
+             )
+
+             self.loaded = True
+             logger.info(f"Model loaded successfully on GPU with {precision} precision.")
+         except Exception as e:
+             logger.error(f"Error loading model: {e}")
+
+     def predict(self, input_text, max_length=50):
+         if not self.loaded:
+             logger.error("Model not loaded. Please load the model before prediction.")
+             return None
+
+         logger.info("========== Start Prediction ==========")
+         try:
+             # Ensure the input_text is a string
+             if not isinstance(input_text, str):
+                 raise ValueError("Input text must be a string.")
+
+             # Encode the input text
+             input_ids = self.tokenizer.encode(input_text, return_tensors='pt')
+
+             # Move input to the same device as the model
+             input_ids = input_ids.to(next(self.model.parameters()).device)
+
+             # Generate output using the model
+             outputs = self.model.generate(input_ids, max_length=max_length)
+
+             # Decode and return the generated text
+             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+             logger.info("Response: {}".format(response))
+         except Exception as e:
+             logger.error(f"Error during prediction: {e}")
+             response = None
+
+         logger.info("========== End Prediction ==========")
+         return response
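A short sketch of driving Model directly, outside the LangChain pipeline in main.py; it assumes a CUDA-capable GPU (load() only logs an error and leaves loaded as False otherwise), reuses the MODEL_NAME from main.py, and the prompt is illustrative:

# Hypothetical standalone use of Model; requires a CUDA-capable GPU.
from model import Model

m = Model("codellama/CodeLlama-7b-Instruct-hf")
m.load(precision="fp16")
if m.loaded:
    print(m.predict("Write a one-line bash command that counts files in /opt/docs.", max_length=128))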
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ fastapi
+ nest-asyncio
+ pyngrok
+ uvicorn
+ accelerate
+ transformers
+ sentence-transformers
+ torch
+ auto-gptq
+ optimum
+ huggingface_hub
+ langchain
+ pypdf
+ qdrant-client
+ streamlit
start_server.sh ADDED
@@ -0,0 +1,22 @@
+ #!/bin/bash
+
+ # Start FastAPI app
+ echo "Starting FastAPI app..."
+ #uvicorn main:app --reload &
+ python3 main.py &
+ # Store FastAPI process ID
+ FASTAPI_PID=$!
+
+ # Start Streamlit app
+ echo "Starting Streamlit app..."
+ streamlit run streamlit_app.py &
+
+ # Store Streamlit process ID
+ STREAMLIT_PID=$!
+
+ # Wait for any process to exit
+ wait -n
+
+ # Kill the other process when one exits
+ kill -TERM $FASTAPI_PID
+ kill -TERM $STREAMLIT_PID
streamlit_app.py ADDED
@@ -0,0 +1,23 @@
+ import streamlit as st
+ import requests
+
+ # Run this with: streamlit run streamlit_app.py
+ # Streamlit interface
+ st.title("Gemini Central Console Bot")
+ user_input = st.text_input("Enter your text here")
+ url = "http://localhost:8000/predict"  # URL of your FastAPI predict endpoint
+
+ if st.button("Submit"):
+     # Prepare the payload
+     payload = {"prompt": user_input}
+
+     # Send the request to the FastAPI endpoint
+     response = requests.post(url, json=payload)
+
+     # Display the response
+     if response.status_code == 200:
+         result = response.json()
+         content = result["choices"][0]["message"]["content"]
+         st.write(content)
+     else:
+         st.write("Failed to get response")