# NetsPresso_QA / run_ralm_netspresso_doc.py
# Uploaded by geonmin-kim (folder upload via huggingface_hub), commit d6585f5.
"""Ask a question to the netspresso database."""
import json
import sys
import argparse
from typing import List
from langchain.chat_models import ChatOpenAI # for `gpt-3.5-turbo` & `gpt-4`
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.prompts import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import BaseRetriever, Document
import gradio as gr
from search_online import OnlineSearcher
# Placeholder text shown in the question box of the Gradio UI.
# (A Korean-language variant of this prompt used to be kept here as a comment.)
DEFAULT_QUESTION = "Ask the Netspresso bot about model lightweighting and optimization.\nFor example \n\n- Why do I need to use Netspresso?\n- Summarize how to compress the model with netspresso.\n- Tell me what the pruning is.\n- What kinds of hardware can I use with this toolkit?\n- Can I use YOLOv8 with this tool? If so, tell me the examples."
# Sampling temperature passed to ChatOpenAI in RaLM.initialize_ralm().
TEMPERATURE = 0

# Hard-coded retriever configuration (FIXME: should come from the CLI).
# NOTE(review): `argparse.Namespace` below is the *class*, not an instance, so
# every assignment sets a class attribute on argparse.Namespace itself. Any
# Namespace created later in the process -- including the one returned by
# `parser.parse_args()` in the __main__ block -- therefore also exposes these
# attributes. LangChainCustomRetrieverWrapper reads the module-global `args.K`
# after __main__ has rebound `args` to the parse_args() result, so this side
# effect is currently load-bearing. Do not switch to `argparse.Namespace()`
# without also handing K to the retriever explicitly.
args = argparse.Namespace
args.index_type = "hybrid"
# Sparse and dense index paths, comma-separated (presumably split into the
# sparse/dense searchers by OnlineSearcher for the "hybrid" index type -- confirm).
# The parentheses only wrap a long line; the value is a plain str, not a tuple.
args.index = (
    "/root/indexes/docs-netspresso-ai/sparse,/root/indexes/docs-netspresso-ai/dense"
)
args.encoder = "castorini/mdpr-question-nq"
args.device = "cuda:0"
args.alpha = 0.5  # presumably the sparse/dense fusion weight -- confirm in OnlineSearcher
args.normalization = True
args.lang_abbr = "en"
args.K = 10  # number of hits requested per query in RETRIEVER.search()

# Build the searcher once at import time; the LangChain wrapper below uses
# this module-level instance directly.
print("initialize NP doc retrieval bot")
RETRIEVER = OnlineSearcher(args)
class LangChainCustomRetrieverWrapper(BaseRetriever):
    """LangChain retriever adapter around the module-level ``RETRIEVER``.

    LangChain chains call :meth:`get_relevant_documents`. This adapter forwards
    the query to the module-level ``RETRIEVER`` (an ``OnlineSearcher``) and
    turns each hit into a LangChain ``Document``.
    """

    def __init__(self, args):
        """Create the wrapper.

        Args:
            args: retriever configuration. Currently unused: the searcher and
                ``K`` are read from module globals instead (see TODO below).
        """
        super().__init__()
        # TODO: keep the searcher/args on the instance. Assigning extra
        # attributes fails because the BaseRetriever object cannot have extra
        # fields, so the module-level RETRIEVER and args are used for now.
        print("Initialize LangChainCustomRetrieverWrapper, TODO: fix minor bug")

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the passages relevant to ``query``.

        Args:
            query: natural-language question to search for.

        Returns:
            One ``Document`` per hit, in the order returned by the searcher.
            ``page_content`` is the ``"contents"`` field of the stored raw JSON
            document, and ``metadata["source"]`` is the hit's docid.
        """
        print(f"query = {query}")
        # NOTE: `args` is the module global. After the __main__ block rebinds
        # it, `K` is still visible only because the config attributes were set
        # on the argparse.Namespace class itself.
        hits = RETRIEVER.search(query, args.K)
        # Raw document text is fetched from the sparse index, even for hits
        # produced by the hybrid searcher.
        return [
            Document(
                page_content=json.loads(
                    RETRIEVER.searcher.sparse_searcher.doc(hit.docid).raw()
                )["contents"],
                metadata={"source": hit.docid},
            )
            for hit in hits
        ]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async counterpart of :meth:`get_relevant_documents`.

        Delegates to the synchronous implementation instead of raising, so
        async LangChain entry points work. Note that the search blocks the
        event loop while it runs.
        """
        return self.get_relevant_documents(query)
class RaLM:
    """Retrieval-augmented LM that answers questions about the NetsPresso docs.

    Wraps a LangChain ``RetrievalQAWithSourcesChain`` ("stuff" chain type).
    The chain puts the retrieved passages into a chat prompt and asks an
    OpenAI chat model to answer while citing its sources.
    """

    def __init__(self, args):
        """Store ``args`` and build the chain.

        Args:
            args: parsed CLI namespace; ``args.model_name`` selects the OpenAI
                chat model.
        """
        self.args = args
        self.initialize_ralm()

    def initialize_ralm(self):
        """Build the retriever, the prompt, the LLM and the QA-with-sources chain."""
        self.retriever = LangChainCustomRetrieverWrapper(self.args)

        # run_chain() splits the chain's "sources" string on ", ", so the
        # prompt asks for a comma-separated source list to match.
        system_template = """Use the following pieces of context to answer the users question.
Take note of the sources and include them in the answer in the format: "SOURCES: source1, source2", use "SOURCES" in capital letters regardless of the number of sources.
Always try to generate answer from source.
----------------
{summaries}"""
        messages = [
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        prompt = ChatPromptTemplate.from_messages(messages)
        chain_type_kwargs = {"prompt": prompt}
        llm = ChatOpenAI(model_name=self.args.model_name, temperature=TEMPERATURE)
        self.chain = RetrievalQAWithSourcesChain.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
            reduce_k_below_max_tokens=True,
            chain_type_kwargs=chain_type_kwargs,
        )

    def run_chain(self, question, force_korean=False):
        """Answer ``question`` and return the post-processed chain output.

        Args:
            question: the user's question.
            force_korean: if True, append an instruction asking the model to
                answer in Korean based on the retrieved text.

        Returns:
            The chain's result dict, with ``"answer"`` cleaned by
            :meth:`postprocess` and ``"sources"`` normalized to a list of
            non-empty, stripped source ids.
        """
        if force_korean:
            # Appended instruction means: "Answer in Korean, referring to the text."
            question = f"{question} 본문을 참고해서 한글로 대답해줘"
        result = self.chain({"question": question})

        result["answer"] = self.postprocess(result["answer"])
        sources = result["sources"]
        if isinstance(sources, str):
            sources = self.postprocess(sources).split(", ")
        # Drop blanks so an answer without any sources yields [] instead of [""].
        result["sources"] = [src.strip() for src in sources if src.strip()]

        self.print_result(result)
        return result

    def print_result(self, result):
        """Print the answer, the cited sources and each cited passage.

        Args:
            result: dict produced by :meth:`run_chain`; ``"sources"`` must be a
                list of source ids. If ``"source_documents"`` is present, the
                first passage whose ``metadata["source"]`` matches a cited id is
                printed under it.

        Raises:
            TypeError: if ``result["sources"]`` is not a list.
        """
        print(f"Answer: {result['answer']}")
        print("Sources: ")
        print(result["sources"])
        if not isinstance(result["sources"], list):
            raise TypeError("result['sources'] must be a list")

        # Map each docid to the first retrieved passage carrying that id.
        passages = {}
        for doc in result.get("source_documents", []):
            passages.setdefault(doc.metadata["source"], doc.page_content)

        for source_title in result["sources"]:
            print(f"{source_title}: ")
            if source_title in passages:
                print(passages[source_title])

    def postprocess(self, text):
        """Strip whitespace and one stray trailing bracket from model output.

        The model sometimes ends its output with a dangling ``(``, ``)``, ``[``
        or ``]`` (cause unknown). Exactly one such character is removed.
        Surrounding whitespace is stripped first, so a bracket followed by a
        newline is still caught.
        """
        text = text.strip()
        if text.endswith(("(", ")", "[", "]")):
            text = text[:-1]
        return text.strip()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Ask a question to the netspresso docs."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="gpt-3.5-turbo-16k-0613",
        help="model name for openai api",
    )
    # Retriever options are fixed at module level for now. Earlier drafts also
    # exposed --question, --query_encoder_name_or_dir and --index_name here.
    # NOTE: this rebinds the module-global `args` that the retriever wrapper reads.
    args = parser.parse_args()
    # Clear our CLI arguments to prevent collision with the DensePhrase native
    # argparser.
    sys.argv = [sys.argv[0]]

    app = RaLM(args)

    def question_answer(question):
        """Gradio callback: return (answer, retrieved passages joined by a separator)."""
        result = app.run_chain(question=question, force_korean=False)
        separator = "\n######################################################\n\n"
        passages = separator.join(
            f"Source {idx}\n{doc.page_content}"
            for idx, doc in enumerate(result["source_documents"])
        )
        return result["answer"], passages

    # Launch the Gradio UI. gr.Textbox(value=...) replaces the deprecated
    # gr.inputs.Textbox(default=...), which also wrongly declared the outputs
    # with the inputs namespace and is gone in Gradio 4.
    gr.Interface(
        fn=question_answer,
        inputs=gr.Textbox(value=DEFAULT_QUESTION, label="Question"),
        outputs=[
            gr.Textbox(value="", label="Bot response"),
            gr.Textbox(value="", label="Search result used by bot"),
        ],
        title="Netspresso Q&A bot",
        # NOTE(review): legacy theme name; newer Gradio versions may warn and
        # fall back to the default theme -- confirm.
        theme="dark-grass",
        description="Ask the Netspresso bot about model lightweighting and optimization.",  # simplified version, hide detail version
        # Detailed (Korean) variant of the description also listed:
        # "retriever: BM25&mdpr-question-nq, generator: gpt-3.5-turbo-16k-0613 (API)"
    ).launch(share=True)