# main.py
import logging
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import nest_asyncio
from pyngrok import ngrok
import uvicorn
import json
from model import Model
from doc_reader import DocReader
from transformers import GenerationConfig, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.schema.runnable import RunnableBranch
from langchain_core.runnables import RunnableLambda
import torch

# Logger configuration
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s [%(levelname)s] %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)


import os
os.system("nvidia-smi")
print("TORCH_CUDA", torch.cuda.is_available())

# Add path to sys
# sys.path.insert(0,'/opt/accelerate')
# sys.path.insert(0,'/opt/uvicorn')
# sys.path.insert(0,'/opt/pyngrok')
# sys.path.insert(0,'/opt/huggingface_hub')
# sys.path.insert(0,'/opt/nest_asyncio')
# sys.path.insert(0,'/opt/transformers')
# sys.path.insert(0,'/opt/pytorch')

# Initialize FastAPI app
app = FastAPI()
#NGROK_TOKEN = "YOUR_NGROK_AUTH_TOKEN"  # Replace with your ngrok auth token
#MODEL_NAME = "/opt/Llama-2-13B-chat-GPTQ"
#MODEL_NAME = "MediaTek-Research/Breeze-7B-Instruct-64k-v0.1"
MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"
PDF_PATH = "/opt/docs"
CLASSIFIER_MODEL_NAME = "roberta-large-mnli"

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
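
# NOTE: allow_origins=['*'] accepts requests from any origin; tighten this list
# (and reconsider allow_credentials=True) before exposing the service publicly.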

model_instance = Model(MODEL_NAME)
model_instance.load()
#model_instance.load(model_name_or_path = GGUF_HUGGINGFACE_REPO, model_basename = GGUF_HUGGINGFACE_BIN_FILE)

# classifier_model = pipeline("zero-shot-classification",
#                       model=CLASSIFIER_MODEL_NAME)


@app.post("/predict")
async def predict_text(request: Request):
    try:
        # Parse request body as JSON
        request_body = await request.json()

        prompt = request_body.get("prompt", "")
        # TODO: handle additional parameters like 'temperature' or 'max_tokens' if needed
        result = general_chain.invoke({"question": prompt})
        logger.info(f"Result: {result}")
        formatted_response = {
            "choices": [
                {
                    "message": {
                        "content": result['result']
                    }
                }
            ]
        }
        return formatted_response
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format"}
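
# Example client call (a minimal sketch, assuming the server is reachable on
# localhost:8000 and the `requests` package is installed -- neither is set up
# by this file):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       json={"prompt": "How do I restart the sbox-admin service?"},
#   )
#   # The response mirrors the OpenAI-style shape built above:
#   # {"choices": [{"message": {"content": "..."}}]}
#   print(resp.json()["choices"][0]["message"]["content"])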

def load_pdfs():
  global db
  doc_reader = DocReader(PDF_PATH)
  # Load PDFs and convert to Markdown
  pages = doc_reader.load_pdfs()
  markdown_text = doc_reader.convert_to_markdown(pages)
  texts = doc_reader.split_text([markdown_text])  # Assuming split_text now takes a list of Markdown texts
  # Generate embeddings
  db = doc_reader.generate_embeddings(texts)
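
# Note: `db` feeds the (currently commented-out) RetrievalQA support chain in
# setup_chain(), so load_pdfs() only needs to run when that retrieval path is enabled.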

# def classify_sequence(input_data):
#     sequence_to_classify = input_data["question"]
#     candidate_labels = ['LinuxCommand', 'TechnicalSupport', 'GeneralResponse']
#     classification = classifier_model(sequence_to_classify, candidate_labels)
#     # Extract the label with the highest score
#     return {"topic": classification['labels'][0], "question": sequence_to_classify}

def format_output(output):
    return {"result": output}

def setup_chain():
  #global full_chain
  #global classifier_chain
  global command_chain
  #global support_chain
  global general_chain
  generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
  generation_config.max_new_tokens = 1024
  generation_config.temperature = 0.3
  generation_config.top_p = 0.9
  generation_config.do_sample = True
  generation_config.repetition_penalty = 1.15

  text_pipeline = pipeline(
    "text-generation",
    model=model_instance.model,
    tokenizer=model_instance.tokenizer,
    return_full_text=True,
    generation_config=generation_config,
  )

  llm = HuggingFacePipeline(pipeline=text_pipeline)
  
  # Classifier
  #classifier_runnable = RunnableLambda(classify_sequence)
  # Formatter
  output_runnable = RunnableLambda(format_output)

  # System Commands
  command_template = """
  [INST] <<SYS>>
  As a Gemini Central engineer specializing in Linux, evaluate the user's input and choose the most likely command they want to execute from these options:
  - 'systemctl stop sbox-admin'
  - 'systemctl start sbox-admin'
  - 'systemctl restart sbox-admin'
  Respond with the chosen command. If uncertain, reply with 'No command will be executed'.
  <</SYS>>
  question:
  {question}
  answer:
  [/INST]"""
  command_chain = (PromptTemplate(template=command_template, input_variables=["question"]) | llm | output_runnable)

  # Support
#   support_template = """
#   [INST] <<SYS>>
#   Act as a Gemini support engineer who is good at reading technical data. Use the following information to answer the question at the end.
#   <</SYS>>
#   {context}
#   {question}
#   answer:
#   [/INST]
#   """


  # General
  general_template = """
  [INST] <<SYS>>
  You are an advanced AI assistant designed to help with a wide range of queries.
  Users may ask you to assume various roles or perform diverse tasks.
  <</SYS>>
  question:
  {question}
  answer:
  [/INST]"""
  general_chain = (PromptTemplate(template=general_template, input_variables=["question"]) | llm | output_runnable)

  #support_prompt = PromptTemplate(template=support_template, input_variables=["context","question"])

  #support_chain = RetrievalQA.from_llm(llm=llm, retriever= db.as_retriever(), prompt=support_prompt, input_key="question", return_source_documents=True, verbose=True)

#   support_chain = RetrievalQA.from_chain_type(
#       llm=llm,
#       chain_type="stuff",
#       #retriever=db.as_retriever(search_kwargs={"k": 3}),
#       retriever=db.as_retriever(),
#       input_key="question",
#       return_source_documents=True,
#       chain_type_kwargs={"prompt": support_prompt},
#       verbose=False
#   )
#   logger.info("support chain loaded successfully.")

  # branch = RunnableBranch(
  #     (lambda x: x == "command", command_chain),
  #     (lambda x: x == "support", support_chain),
  #     general_chain,  # Default chain
  # )

#   def route_classification(output):
#     if output['topic'] == 'LinuxCommand':
#         logger.info("Routing to command chain")
#         return command_chain
#     elif output['topic'] == 'TechnicalSupport':
#         logger.info("Routing to support chain")
#         return support_chain
#     else:
#         logger.info("Routing to general chain")
#         return general_chain

#   routing_runnable = RunnableLambda(route_classification)

  # Full chain integration
  #full_chain = classifier_runnable | routing_runnable

  #logger.info("Full chain loaded successfully.")
  return general_chain
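
# Quick local check (a sketch; assumes setup_chain() below has already run and
# populated the module-level `general_chain`):
#
#   out = general_chain.invoke({"question": "How do I restart sbox-admin?"})
#   print(out["result"])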


###############
# launch once at startup
#load_pdfs()
setup_chain()
###############

#if __name__ == "__main__":

    # if NGROK_TOKEN is not None:
    #     ngrok.set_auth_token(NGROK_TOKEN)

    # ngrok_tunnel = ngrok.connect(8000)
    # public_url = ngrok_tunnel.public_url

    # print('Public URL:', public_url)
    # print("You can use {}/predict to get the assistant result.".format(public_url))
    # logger.info("You can use {}/predict to get the assistant result.".format(public_url))

    #nest_asyncio.apply()
    #uvicorn.run(app, port=8000)
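
# Alternatively, the app can be served without the ngrok block above (a sketch,
# assuming uvicorn is installed and this file is saved as main.py):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000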