# main.py
import logging
import json
import os

import torch
import uvicorn
import nest_asyncio
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pyngrok import ngrok
from transformers import GenerationConfig, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.schema.runnable import RunnableBranch
from langchain_core.runnables import RunnableLambda

from model import Model
from doc_reader import DocReader
# Logger configuration
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s [%(levelname)s] %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)

# Print GPU details and confirm that CUDA is visible to PyTorch
os.system("nvidia-smi")
print("TORCH_CUDA", torch.cuda.is_available())
# Add path to sys
# sys.path.insert(0, '/opt/accelerate')
# sys.path.insert(0, '/opt/uvicorn')
# sys.path.insert(0, '/opt/pyngrok')
# sys.path.insert(0, '/opt/huggingface_hub')
# sys.path.insert(0, '/opt/nest_asyncio')
# sys.path.insert(0, '/opt/transformers')
# sys.path.insert(0, '/opt/pytorch')
# Initialize FastAPI app
app = FastAPI()

#NGROK_TOKEN = "2aQUM6MDkhjcPEBbIFTiu4cZBBr_sMMei8h5yejFbxFeMFuQ"  # Replace with your NGROK token
#MODEL_NAME = "/opt/Llama-2-13B-chat-GPTQ"
#MODEL_NAME = "MediaTek-Research/Breeze-7B-Instruct-64k-v0.1"
MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"
PDF_PATH = "/opt/docs"
CLASSIFIER_MODEL_NAME = "roberta-large-mnli"
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
model_instance = Model(MODEL_NAME)
model_instance.load()
#model_instance.load(model_name_or_path=GGUF_HUGGINGFACE_REPO, model_basename=GGUF_HUGGINGFACE_BIN_FILE)

# classifier_model = pipeline("zero-shot-classification",
#                             model=CLASSIFIER_MODEL_NAME)
# Route path assumed from the ngrok hint at the bottom of this file ("{public_url}/predict")
@app.post("/predict")
async def predict_text(request: Request):
    try:
        # Parse request body as JSON
        request_body = await request.json()
        prompt = request_body.get("prompt", "")
        # TODO: handle additional parameters like 'temperature' or 'max_tokens' if needed
        result = general_chain.invoke({"question": prompt})
        logger.info(f"Result: {result}")
        formatted_response = {
            "choices": [
                {
                    "message": {
                        "content": result['result']
                    }
                }
            ]
        }
        return formatted_response
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format"}
def load_pdfs():
    global db
    doc_reader = DocReader(PDF_PATH)
    # Load PDFs and convert to Markdown
    pages = doc_reader.load_pdfs()
    markdown_text = doc_reader.convert_to_markdown(pages)
    texts = doc_reader.split_text([markdown_text])  # Assuming split_text now takes a list of Markdown texts
    # Generate embeddings
    db = doc_reader.generate_embeddings(texts)
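# Note: `db` is the vector store consumed by the (currently disabled) RetrievalQA support chain
# in setup_chain(); load_pdfs() only needs to run once that chain is re-enabled.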
# def classify_sequence(input_data):
#     sequence_to_classify = input_data["question"]
#     candidate_labels = ['LinuxCommand', 'TechnicalSupport', 'GeneralResponse']
#     classification = classifier_model(sequence_to_classify, candidate_labels)
#     # Extract the label with the highest score
#     return {"topic": classification['labels'][0], "question": sequence_to_classify}

def format_output(output):
    return {"result": output}
def setup_chain():
    #global full_chain
    #global classifier_chain
    global command_chain
    #global support_chain
    global general_chain

    generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
    generation_config.max_new_tokens = 1024
    generation_config.temperature = 0.3
    generation_config.top_p = 0.9
    generation_config.do_sample = True
    generation_config.repetition_penalty = 1.15

    text_pipeline = pipeline(
        "text-generation",
        model=model_instance.model,
        tokenizer=model_instance.tokenizer,
        return_full_text=True,
        generation_config=generation_config,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline)
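    # `llm` now wraps the transformers pipeline as a LangChain-compatible LLM, so it can be
    # composed with the PromptTemplate objects below via the `|` (LCEL) operator.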
    # Classifier
    #classifier_runnable = RunnableLambda(classify_sequence)

    # Formatter
    output_runnable = RunnableLambda(format_output)

    # System Commands
    command_template = """
[INST] <<SYS>>
As a Gemini Central engineer specializing in Linux, evaluate the user's input and choose the most likely command they want to execute from these options:
- 'systemctl stop sbox-admin'
- 'systemctl start sbox-admin'
- 'systemctl restart sbox-admin'
Respond with the chosen command. If uncertain, reply with 'No command will be executed'.
<</SYS>>
question:
{question}
answer:
[/INST]"""
    command_chain = (PromptTemplate(template=command_template, input_variables=["question"]) | llm | output_runnable)
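    # command_chain maps free-form user text to one of the three systemctl commands above and,
    # like general_chain, returns {"result": <raw model text>} via format_output. It is only
    # reachable once the classifier/routing branch (commented out below) is re-enabled.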
    # Support
    # support_template = """
    # [INST] <<SYS>>
    # Act as a Gemini support engineer who is good at reading technical data. Use the following information to answer the question at the end.
    # <</SYS>>
    # {context}
    # {question}
    # answer:
    # [/INST]
    # """

    # General
    general_template = """
[INST] <<SYS>>
You are an advanced AI assistant designed to provide assistance with a wide range of queries.
Users may request you to assume various roles or perform diverse tasks
<</SYS>>
question:
{question}
answer:
[/INST]"""
    general_chain = (PromptTemplate(template=general_template, input_variables=["question"]) | llm | output_runnable)
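    # Example (for local testing, after setup_chain() has run):
    #   general_chain.invoke({"question": "What does 'systemctl status' do?"})
    # returns {"result": "<model output>"}, which predict_text() reads via result['result'].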
    #support_prompt = PromptTemplate(template=support_template, input_variables=["context", "question"])
    #support_chain = RetrievalQA.from_llm(llm=llm, retriever=db.as_retriever(), prompt=support_prompt, input_key="question", return_source_documents=True, verbose=True)
    # support_chain = RetrievalQA.from_chain_type(
    #     llm=llm,
    #     chain_type="stuff",
    #     #retriever=db.as_retriever(search_kwargs={"k": 3}),
    #     retriever=db.as_retriever(),
    #     input_key="question",
    #     return_source_documents=True,
    #     chain_type_kwargs={"prompt": support_prompt},
    #     verbose=False
    # )
    # logger.info("support chain loaded successfully.")

    # branch = RunnableBranch(
    #     (lambda x: x == "command", command_chain),
    #     (lambda x: x == "support", support_chain),
    #     general_chain,  # Default chain
    # )

    # def route_classification(output):
    #     if output['topic'] == 'LinuxCommand':
    #         logger.info("Routing to command chain")
    #         return command_chain
    #     elif output['topic'] == 'TechnicalSupport':
    #         logger.info("Routing to support chain")
    #         return support_chain
    #     else:
    #         logger.info("Routing to general chain")
    #         return general_chain
    # routing_runnable = RunnableLambda(route_classification)

    # Full chain integration
    #full_chain = classifier_runnable | routing_runnable
    #logger.info("Full chain loaded successfully.")

    return general_chain
###############
# launch once at startup
#load_pdfs()
setup_chain()
###############
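# setup_chain() runs at import time so the model pipeline and chains are ready before the first
# request; load_pdfs() stays disabled until the retrieval-based support chain is re-enabled.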
#if __name__ == "__main__":
#    if NGROK_TOKEN is not None:
#        ngrok.set_auth_token(NGROK_TOKEN)
#        ngrok_tunnel = ngrok.connect(8000)
#        public_url = ngrok_tunnel.public_url
#        print('Public URL:', public_url)
#        print("You can use {}/predict to get the assistant result.".format(public_url))
#        logger.info("You can use {}/predict to get the assistant result.".format(public_url))
#    nest_asyncio.apply()
#    uvicorn.run(app, port=8000)
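# To serve without ngrok (port taken from the commented block above; adjust for your host):
#   uvicorn main:app --host 0.0.0.0 --port 8000
# then POST a JSON body like {"prompt": "..."} to /predict.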