# main.py
import json
import logging
import os

import nest_asyncio
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pyngrok import ngrok
from transformers import GenerationConfig, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.schema.runnable import RunnableBranch
from langchain_core.runnables import RunnableLambda

from model import Model
from doc_reader import DocReader

# Logger configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

# Quick GPU diagnostics at startup
os.system("nvidia-smi")
print("TORCH_CUDA", torch.cuda.is_available())

# Add path to sys
# sys.path.insert(0, '/opt/accelerate')
# sys.path.insert(0, '/opt/uvicorn')
# sys.path.insert(0, '/opt/pyngrok')
# sys.path.insert(0, '/opt/huggingface_hub')
# sys.path.insert(0, '/opt/nest_asyncio')
# sys.path.insert(0, '/opt/transformers')
# sys.path.insert(0, '/opt/pytorch')

# Initialize FastAPI app
app = FastAPI()

# NGROK_TOKEN = "2aQUM6MDkhjcPEBbIFTiu4cZBBr_sMMei8h5yejFbxFeMFuQ"  # Replace with your NGROK token
# MODEL_NAME = "/opt/Llama-2-13B-chat-GPTQ"
# MODEL_NAME = "MediaTek-Research/Breeze-7B-Instruct-64k-v0.1"
MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"
PDF_PATH = "/opt/docs"
CLASSIFIER_MODEL_NAME = "roberta-large-mnli"

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)

# Load the LLM once at import time so all requests share a single instance
model_instance = Model(MODEL_NAME)
model_instance.load()
# model_instance.load(model_name_or_path=GGUF_HUGGINGFACE_REPO, model_basename=GGUF_HUGGINGFACE_BIN_FILE)

# classifier_model = pipeline("zero-shot-classification",
#                             model=CLASSIFIER_MODEL_NAME)


@app.post("/predict")
async def predict_text(request: Request):
    try:
        # Parse request body as JSON
        request_body = await request.json()
        prompt = request_body.get("prompt", "")
        # TODO: handle additional parameters like 'temperature' or 'max_tokens' if needed
        result = general_chain.invoke({"question": prompt})
        logger.info(f"Result: {result}")
        # Mirror the OpenAI chat-completion response shape
        formatted_response = {
            "choices": [
                {
                    "message": {
                        "content": result['result']
                    }
                }
            ]
        }
        return formatted_response
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format"}


def load_pdfs():
    global db
    doc_reader = DocReader(PDF_PATH)
    # Load PDFs and convert to Markdown
    pages = doc_reader.load_pdfs()
    markdown_text = doc_reader.convert_to_markdown(pages)
    texts = doc_reader.split_text([markdown_text])  # split_text takes a list of Markdown texts
    # Generate embeddings
    db = doc_reader.generate_embeddings(texts)


# def classify_sequence(input_data):
#     sequence_to_classify = input_data["question"]
#     candidate_labels = ['LinuxCommand', 'TechnicalSupport', 'GeneralResponse']
#     classification = classifier_model(sequence_to_classify, candidate_labels)
#     # Extract the label with the highest score
#     return {"topic": classification['labels'][0], "question": sequence_to_classify}


def format_output(output):
    # Wrap raw LLM text so every chain returns a dict with a 'result' key
    return {"result": output}


def setup_chain():
    # global full_chain
    # global classifier_chain
    global command_chain
    # global support_chain
    global general_chain

    generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
    generation_config.max_new_tokens = 1024
    generation_config.temperature = 0.3
    generation_config.top_p = 0.9
    generation_config.do_sample = True
    generation_config.repetition_penalty = 1.15

    text_pipeline = pipeline(
        "text-generation",
        model=model_instance.model,
        tokenizer=model_instance.tokenizer,
        return_full_text=True,
        generation_config=generation_config,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline)

    # Classifier
    # classifier_runnable = RunnableLambda(classify_sequence)

    # Formatter
    output_runnable = RunnableLambda(format_output)

    # System commands
    command_template = """
[INST] <<SYS>>
As a Gemini Central engineer specializing in Linux, evaluate the user's input and choose the most likely command they want to execute from these options:
- 'systemctl stop sbox-admin'
- 'systemctl start sbox-admin'
- 'systemctl restart sbox-admin'
Respond with the chosen command. If uncertain, reply with 'No command will be executed'.
<</SYS>>
question: {question}
answer: [/INST]"""
    command_chain = (
        PromptTemplate(template=command_template, input_variables=["question"])
        | llm
        | output_runnable
    )

    # Support
    # support_template = """
    # [INST] <<SYS>>
    # Act as a Gemini support engineer who is good at reading technical data. Use the following information to answer the question at the end.
    # <</SYS>>
    # {context}
    # {question}
    # answer:
    # [/INST]
    # """

    # General
    general_template = """
[INST] <<SYS>>
You are an advanced AI assistant designed to provide assistance with a wide range of queries. Users may request you to assume various roles or perform diverse tasks.
<</SYS>>
question: {question}
answer: [/INST]"""
    general_chain = (
        PromptTemplate(template=general_template, input_variables=["question"])
        | llm
        | output_runnable
    )

    # support_prompt = PromptTemplate(template=support_template, input_variables=["context", "question"])
    # support_chain = RetrievalQA.from_llm(llm=llm, retriever=db.as_retriever(), prompt=support_prompt, input_key="question", return_source_documents=True, verbose=True)

    # support_chain = RetrievalQA.from_chain_type(
    #     llm=llm,
    #     chain_type="stuff",
    #     # retriever=db.as_retriever(search_kwargs={"k": 3}),
    #     retriever=db.as_retriever(),
    #     input_key="question",
    #     return_source_documents=True,
    #     chain_type_kwargs={"prompt": support_prompt},
    #     verbose=False,
    # )
    # logger.info("support chain loaded successfully.")

    # branch = RunnableBranch(
    #     (lambda x: x == "command", command_chain),
    #     (lambda x: x == "support", support_chain),
    #     general_chain,  # Default chain
    # )

    # def route_classification(output):
    #     if output['topic'] == 'LinuxCommand':
    #         logger.info("Routing to command chain")
    #         return command_chain
    #     elif output['topic'] == 'TechnicalSupport':
    #         logger.info("Routing to support chain")
    #         return support_chain
    #     else:
    #         logger.info("Routing to general chain")
    #         return general_chain

    # routing_runnable = RunnableLambda(route_classification)

    # Full chain integration
    # full_chain = classifier_runnable | routing_runnable
    # logger.info("Full chain loaded successfully.")

    return general_chain


###############
# launch once at startup
# load_pdfs()
setup_chain()
###############

# if __name__ == "__main__":
#     if NGROK_TOKEN is not None:
#         ngrok.set_auth_token(NGROK_TOKEN)
#     ngrok_tunnel = ngrok.connect(8000)
#     public_url = ngrok_tunnel.public_url
#     print('Public URL:', public_url)
#     print("You can use {}/predict to get the assistant result.".format(public_url))
#     logger.info("You can use {}/predict to get the assistant result.".format(public_url))

# nest_asyncio.apply()
# uvicorn.run(app, port=8000)
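
# Example client call (a sketch, not part of the service): assuming the app is served
# on port 8000 (e.g. `uvicorn main:app --port 8000`), a POST to /predict takes a JSON
# body with a "prompt" field and returns an OpenAI-style "choices" list, as built in
# predict_text() above. The host, port, and prompt below are placeholders — adjust
# them to your deployment.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       json={"prompt": "How do I restart the sbox-admin service?"},
#       timeout=120,
#   )
#   print(resp.json()["choices"][0]["message"]["content"])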