# main.py
import logging
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import nest_asyncio
from pyngrok import ngrok
import uvicorn
import json
import os
from model import Model
from doc_reader import DocReader
from transformers import GenerationConfig, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.schema.runnable import RunnableBranch
from langchain_core.runnables import RunnableLambda
import torch
# Logger configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)
# Quick GPU sanity checks at startup
os.system("nvidia-smi")
logger.info("TORCH_CUDA available: %s", torch.cuda.is_available())
# Add path to sys
# sys.path.insert(0,'/opt/accelerate')
# sys.path.insert(0,'/opt/uvicorn')
# sys.path.insert(0,'/opt/pyngrok')
# sys.path.insert(0,'/opt/huggingface_hub')
# sys.path.insert(0,'/opt/nest_asyncio')
# sys.path.insert(0,'/opt/transformers')
# sys.path.insert(0,'/opt/pytorch')
# Initialize FastAPI app
app = FastAPI()
#NGROK_TOKEN = "<YOUR_NGROK_AUTH_TOKEN>"  # Replace with your ngrok auth token; do not commit real tokens
#MODEL_NAME = "/opt/Llama-2-13B-chat-GPTQ"
#MODEL_NAME = "MediaTek-Research/Breeze-7B-Instruct-64k-v0.1"
MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"
PDF_PATH = "/opt/docs"
CLASSIFIER_MODEL_NAME = "roberta-large-mnli"
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
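# NOTE: allow_origins=['*'] together with allow_credentials=True is convenient for
# development, but should be narrowed to explicit origins before exposing this service publicly.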
model_instance = Model(MODEL_NAME)
model_instance.load()
#model_instance.load(model_name_or_path=GGUF_HUGGINGFACE_REPO, model_basename=GGUF_HUGGINGFACE_BIN_FILE)
# classifier_model = pipeline("zero-shot-classification",
#                             model=CLASSIFIER_MODEL_NAME)
@app.post("/predict")
async def predict_text(request: Request):
    try:
        # Parse the request body as JSON
        request_body = await request.json()
        prompt = request_body.get("prompt", "")
        # TODO: handle additional parameters like 'temperature' or 'max_tokens' if needed
        result = general_chain.invoke({"question": prompt})
        logger.info(f"Result: {result}")
        formatted_response = {
            "choices": [
                {
                    "message": {
                        "content": result['result']
                    }
                }
            ]
        }
        return formatted_response
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format"}
def load_pdfs():
    global db
    doc_reader = DocReader(PDF_PATH)
    # Load PDFs and convert to Markdown
    pages = doc_reader.load_pdfs()
    markdown_text = doc_reader.convert_to_markdown(pages)
    texts = doc_reader.split_text([markdown_text])  # Assuming split_text takes a list of Markdown texts
    # Generate embeddings and build the vector store
    db = doc_reader.generate_embeddings(texts)
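# The vector store built by load_pdfs() ("db") is only consumed by the commented-out
# RetrievalQA support chain below; the call to load_pdfs() stays disabled until that
# chain is re-enabled.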
# def classify_sequence(input_data):
#     sequence_to_classify = input_data["question"]
#     candidate_labels = ['LinuxCommand', 'TechnicalSupport', 'GeneralResponse']
#     classification = classifier_model(sequence_to_classify, candidate_labels)
#     # Extract the label with the highest score
#     return {"topic": classification['labels'][0], "question": sequence_to_classify}
def format_output(output):
    return {"result": output}
def setup_chain():
    #global full_chain
    #global classifier_chain
    global command_chain
    #global support_chain
    global general_chain

    generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
    generation_config.max_new_tokens = 1024
    generation_config.temperature = 0.3
    generation_config.top_p = 0.9
    generation_config.do_sample = True
    generation_config.repetition_penalty = 1.15

    text_pipeline = pipeline(
        "text-generation",
        model=model_instance.model,
        tokenizer=model_instance.tokenizer,
        return_full_text=True,
        generation_config=generation_config,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline)
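    # return_full_text=True makes the HF text-generation pipeline echo the prompt before
    # the completion; switch it to False if only the generated answer should flow into
    # the chains below.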
    # Classifier
    #classifier_runnable = RunnableLambda(classify_sequence)
    # Formatter
    output_runnable = RunnableLambda(format_output)

    # System commands
    command_template = """
[INST] <<SYS>>
As a Gemini Central engineer specializing in Linux, evaluate the user's input and choose the most likely command they want to execute from these options:
- 'systemctl stop sbox-admin'
- 'systemctl start sbox-admin'
- 'systemctl restart sbox-admin'
Respond with the chosen command. If uncertain, reply with 'No command will be executed'.
<</SYS>>
question:
{question}
answer:
[/INST]"""
    command_chain = (
        PromptTemplate(template=command_template, input_variables=["question"])
        | llm
        | output_runnable
    )
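    # PromptTemplate | llm | output_runnable is a LangChain Expression Language (LCEL)
    # pipeline: the template is filled with {"question": ...}, sent to the LLM, and the
    # raw text is wrapped by format_output into {"result": ...}.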
    # Support
    # support_template = """
    # [INST] <<SYS>>
    # Act as a Gemini support engineer who is good at reading technical data. Use the following information to answer the question at the end.
    # <</SYS>>
    # {context}
    # {question}
    # answer:
    # [/INST]
    # """

    # General
    general_template = """
[INST] <<SYS>>
You are an advanced AI assistant designed to provide assistance with a wide range of queries.
Users may ask you to assume various roles or perform diverse tasks.
<</SYS>>
question:
{question}
answer:
[/INST]"""
    general_chain = (
        PromptTemplate(template=general_template, input_variables=["question"])
        | llm
        | output_runnable
    )
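    # Classification-based routing (RunnableBranch / route_classification below) is
    # currently disabled, so every request is answered by general_chain.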
    #support_prompt = PromptTemplate(template=support_template, input_variables=["context", "question"])
    #support_chain = RetrievalQA.from_llm(llm=llm, retriever=db.as_retriever(), prompt=support_prompt, input_key="question", return_source_documents=True, verbose=True)
    # support_chain = RetrievalQA.from_chain_type(
    #     llm=llm,
    #     chain_type="stuff",
    #     #retriever=db.as_retriever(search_kwargs={"k": 3}),
    #     retriever=db.as_retriever(),
    #     input_key="question",
    #     return_source_documents=True,
    #     chain_type_kwargs={"prompt": support_prompt},
    #     verbose=False,
    # )
    # logger.info("support chain loaded successfully.")

    # branch = RunnableBranch(
    #     (lambda x: x == "command", command_chain),
    #     (lambda x: x == "support", support_chain),
    #     general_chain,  # Default chain
    # )

    # def route_classification(output):
    #     if output['topic'] == 'LinuxCommand':
    #         logger.info("Routing to command chain")
    #         return command_chain
    #     elif output['topic'] == 'TechnicalSupport':
    #         logger.info("Routing to support chain")
    #         return support_chain
    #     else:
    #         logger.info("Routing to general chain")
    #         return general_chain
    # routing_runnable = RunnableLambda(route_classification)

    # Full chain integration
    #full_chain = classifier_runnable | routing_runnable
    #logger.info("Full chain loaded successfully.")
    return general_chain
###############
# launch once at startup
#load_pdfs()
setup_chain()
###############
#if __name__ == "__main__":
#    if NGROK_TOKEN is not None:
#        ngrok.set_auth_token(NGROK_TOKEN)
#    ngrok_tunnel = ngrok.connect(8000)
#    public_url = ngrok_tunnel.public_url
#    print('Public URL:', public_url)
#    print("You can use {}/predict to get the assistant result.".format(public_url))
#    logger.info("You can use {}/predict to get the assistant result.".format(public_url))
#    nest_asyncio.apply()
#    uvicorn.run(app, port=8000)
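# To serve the app directly (a sketch; assumes this file is named main.py and port 8000 is free):
#   uvicorn main:app --host 0.0.0.0 --port 8000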