# main.py
import logging
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import nest_asyncio
from pyngrok import ngrok
import uvicorn
import json
from model import Model
from doc_reader import DocReader
from transformers import GenerationConfig, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.schema.runnable import RunnableBranch
from langchain_core.runnables import RunnableLambda
import torch
# Logger configuration
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s [%(levelname)s] %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)
import os
os.system("nvidia-smi")
print("TORCH_CUDA", torch.cuda.is_available())
# Add path to sys
# sys.path.insert(0,'/opt/accelerate')
# sys.path.insert(0,'/opt/uvicorn')
# sys.path.insert(0,'/opt/pyngrok')
# sys.path.insert(0,'/opt/huggingface_hub')
# sys.path.insert(0,'/opt/nest_asyncio')
# sys.path.insert(0,'/opt/transformers')
# sys.path.insert(0,'/opt/pytorch')
# Initialize FastAPI app
app = FastAPI()
#NGROK_TOKEN = "2aQUM6MDkhjcPEBbIFTiu4cZBBr_sMMei8h5yejFbxFeMFuQ" # Replace with your NGROK token
#MODEL_NAME = "/opt/Llama-2-13B-chat-GPTQ"
#MODEL_NAME = "MediaTek-Research/Breeze-7B-Instruct-64k-v0.1"
MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"
PDF_PATH = "/opt/docs"
CLASSIFIER_MODEL_NAME = "roberta-large-mnli"
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
model_instance = Model(MODEL_NAME)
model_instance.load()
#model_instance.load(model_name_or_path=GGUF_HUGGINGFACE_REPO, model_basename=GGUF_HUGGINGFACE_BIN_FILE)
# classifier_model = pipeline("zero-shot-classification",
# model=CLASSIFIER_MODEL_NAME)
@app.post("/predict")
async def predict_text(request: Request):
    try:
        # Parse request body as JSON
        request_body = await request.json()
        prompt = request_body.get("prompt", "")
        # TODO: handle additional parameters like 'temperature' or 'max_tokens' if needed
        result = general_chain.invoke({"question": prompt})
        logger.info(f"Result: {result}")
        # Wrap the chain output in an OpenAI-style "choices" envelope
        formatted_response = {
            "choices": [
                {
                    "message": {
                        "content": result['result']
                    }
                }
            ]
        }
        return formatted_response
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format"}
def load_pdfs():
    global db
    doc_reader = DocReader(PDF_PATH)
    # Load PDFs and convert to Markdown
    pages = doc_reader.load_pdfs()
    markdown_text = doc_reader.convert_to_markdown(pages)
    texts = doc_reader.split_text([markdown_text])  # Assuming split_text now takes a list of Markdown texts
    # Generate embeddings
    db = doc_reader.generate_embeddings(texts)
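# Illustrative retrieval sketch (assumption: generate_embeddings returns a LangChain
# vector store exposing as_retriever(), which is what the commented-out support_chain
# in setup_chain() expects; load_pdfs() is currently disabled at startup):
#
#   load_pdfs()
#   docs = db.as_retriever().get_relevant_documents("How do I restart sbox-admin?")
#   for d in docs:
#       print(d.page_content[:200])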
# def classify_sequence(input_data):
#     sequence_to_classify = input_data["question"]
#     candidate_labels = ['LinuxCommand', 'TechnicalSupport', 'GeneralResponse']
#     classification = classifier_model(sequence_to_classify, candidate_labels)
#     # Extract the label with the highest score
#     return {"topic": classification['labels'][0], "question": sequence_to_classify}
def format_output(output):
    return {"result": output}
def setup_chain():
    #global full_chain
    #global classifier_chain
    global command_chain
    #global support_chain
    global general_chain
    # Generation settings for the text-generation pipeline
    generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
    generation_config.max_new_tokens = 1024
    generation_config.temperature = 0.3
    generation_config.top_p = 0.9
    generation_config.do_sample = True
    generation_config.repetition_penalty = 1.15
    text_pipeline = pipeline(
        "text-generation",
        model=model_instance.model,
        tokenizer=model_instance.tokenizer,
        return_full_text=True,
        generation_config=generation_config,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline)
    # Classifier
    #classifier_runnable = RunnableLambda(classify_sequence)
    # Formatter
    output_runnable = RunnableLambda(format_output)
    # System Commands
    command_template = """
[INST] <<SYS>>
As a Gemini Central engineer specializing in Linux, evaluate the user's input and choose the most likely command they want to execute from these options:
- 'systemctl stop sbox-admin'
- 'systemctl start sbox-admin'
- 'systemctl restart sbox-admin'
Respond with the chosen command. If uncertain, reply with 'No command will be executed'.
<</SYS>>
question:
{question}
answer:
[/INST]"""
    command_chain = (PromptTemplate(template=command_template, input_variables=["question"]) | llm | output_runnable)
    # Support
    # support_template = """
    # [INST] <<SYS>>
    # Act as a Gemini support engineer who is good at reading technical data. Use the following information to answer the question at the end.
    # <</SYS>>
    # {context}
    # {question}
    # answer:
    # [/INST]
    # """
    # General
    general_template = """
[INST] <<SYS>>
You are an advanced AI assistant designed to provide assistance with a wide range of queries.
Users may request you to assume various roles or perform diverse tasks
<</SYS>>
question:
{question}
answer:
[/INST]"""
    general_chain = (PromptTemplate(template=general_template, input_variables=["question"]) | llm | output_runnable)
    #support_prompt = PromptTemplate(template=support_template, input_variables=["context", "question"])
    #support_chain = RetrievalQA.from_llm(llm=llm, retriever=db.as_retriever(), prompt=support_prompt, input_key="question", return_source_documents=True, verbose=True)
    # support_chain = RetrievalQA.from_chain_type(
    #     llm=llm,
    #     chain_type="stuff",
    #     #retriever=db.as_retriever(search_kwargs={"k": 3}),
    #     retriever=db.as_retriever(),
    #     input_key="question",
    #     return_source_documents=True,
    #     chain_type_kwargs={"prompt": support_prompt},
    #     verbose=False
    # )
    # logger.info("support chain loaded successfully.")
    # branch = RunnableBranch(
    #     (lambda x: x == "command", command_chain),
    #     (lambda x: x == "support", support_chain),
    #     general_chain,  # Default chain
    # )
    # def route_classification(output):
    #     if output['topic'] == 'LinuxCommand':
    #         logger.info("Routing to command chain")
    #         return command_chain
    #     elif output['topic'] == 'TechnicalSupport':
    #         logger.info("Routing to support chain")
    #         return support_chain
    #     else:
    #         logger.info("Routing to general chain")
    #         return general_chain
    # routing_runnable = RunnableLambda(route_classification)
    # Full chain integration
    #full_chain = classifier_runnable | routing_runnable
    #logger.info("Full chain loaded successfully.")
    return general_chain
###############
# launch once at startup
#load_pdfs()
setup_chain()
###############
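# Example of invoking the command routing chain directly (sketch for manual testing;
# the dict output shape comes from format_output above):
#
#   cmd = command_chain.invoke({"question": "Please restart the sbox-admin service"})
#   print(cmd["result"])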
#if __name__ == "__main__":
#    if NGROK_TOKEN is not None:
#        ngrok.set_auth_token(NGROK_TOKEN)
#    ngrok_tunnel = ngrok.connect(8000)
#    public_url = ngrok_tunnel.public_url
#    print('Public URL:', public_url)
#    print("You can use {}/predict to get the assistant result.".format(public_url))
#    logger.info("You can use {}/predict to get the assistant result.".format(public_url))
#nest_asyncio.apply()
#uvicorn.run(app, port=8000)