qatool / common.py
naotakigawa's picture
Upload 3 files
0a9b238
raw
history blame
6.29 kB
import streamlit as st
import os
import pickle
import ipaddress
import tiktoken
from pathlib import Path
from streamlit import runtime
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit.web.server.websocket_headers import _get_websocket_headers
from llama_index import SimpleDirectoryReader
# from llama_index import Prompt
from llama_index.prompts.base import PromptTemplate
from llama_index.chat_engine import CondenseQuestionChatEngine;
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index import ServiceContext, SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
from llama_index.constants import DEFAULT_CHUNK_OVERLAP
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.callbacks import CallbackManager
from llama_index.llms import OpenAI
from log import logger
from llama_index.llms.base import ChatMessage, MessageRole
from llama_index.prompts.base import ChatPromptTemplate
# 接続元制御
ALLOW_IP_ADDRESS = os.environ["ALLOW_IP_ADDRESS"]
# Azure AD app registration details
CLIENT_ID = os.environ["CLIENT_ID"]
CLIENT_SECRET = os.environ["CLIENT_SECRET"]
TENANT_ID = os.environ["TENANT_ID"]
# Azure API
REDIRECT_URI = os.environ["REDIRECT_URI"]
AUTHORITY = f"https://login.microsoftonline.com/{TENANT_ID}"
SCOPES = ["openid", "profile", "User.Read"]
# 接続元IP取得
def get_remote_ip():
ctx = get_script_run_ctx()
session_info = runtime.get_instance().get_client(ctx.session_id)
headers = _get_websocket_headers()
return session_info.request.remote_ip, headers.get("X-Forwarded-For")
# 接続元IP許可判定
def is_allow_ip_address():
remote_ip, x_forwarded_for = get_remote_ip()
logger.info("remote_ip:"+remote_ip)
if x_forwarded_for is not None:
remote_ip = x_forwarded_for
# localhost
if remote_ip == "::1":
return True
# プライベートIP
ipaddr = ipaddress.IPv4Address(remote_ip)
logger.info("ipaddr:"+str(ipaddr))
if ipaddr.is_private:
return True
# その他(許可リスト判定)
return remote_ip in ALLOW_IP_ADDRESS
#ログインの確認
def check_login():
if not is_allow_ip_address():
st.title("HTTP 403 Forbidden")
st.stop()
if "login_token" not in st.session_state or not st.session_state.login_token:
st.warning("**ログインしてください**")
st.stop()
INDEX_NAME = os.environ["INDEX_NAME"]
PKL_NAME = os.environ["PKL_NAME"]
# デバッグ用
llm = OpenAI(model='gpt-3.5-turbo', temperature=0.8, max_tokens=256)
text_splitter = TokenTextSplitter(separator="。", chunk_size=1500
, chunk_overlap=DEFAULT_CHUNK_OVERLAP
, tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode)
node_parser = SimpleNodeParser(text_splitter=text_splitter)
custom_prompt = PromptTemplate("""\
以下はこれまでの会話履歴と、ドキュメントを検索して回答する必要がある、ユーザーからの会話文です。
会話と新しい会話文に基づいて、検索クエリを作成します。
<Chat History>
{chat_history}
<Follow Up Message>
{question}
<Standalone question>
""")
TEXT_QA_SYSTEM_PROMPT = ChatMessage(
content=(
"あなたは世界中で信頼されているQAシステムです。\n"
"事前知識ではなく、常に提供されたコンテキスト情報を使用してクエリに回答してください。\n"
"従うべきいくつかのルール:\n"
"1. 回答内で指定されたコンテキストを直接参照しないでください。\n"
"2. 「コンテキストに基づいて、...」や「コンテキスト情報は...」、またはそれに類するような記述は避けてください。"
),
role=MessageRole.SYSTEM,
)
# QAプロンプトテンプレートメッセージ
TEXT_QA_PROMPT_TMPL_MSGS = [
TEXT_QA_SYSTEM_PROMPT,
ChatMessage(
content=(
"コンテキスト情報は以下のとおりです。\n"
"---------------------\n"
"{context_str}\n"
"---------------------\n"
"事前知識ではなくコンテキスト情報を考慮して、クエリに答えます。\n"
"Query: {query_str}\n"
"Answer: "
),
role=MessageRole.USER,
),
]
CHAT_TEXT_QA_PROMPT = ChatPromptTemplate(message_templates=TEXT_QA_PROMPT_TMPL_MSGS)
CHAT_REFINE_PROMPT_TMPL_MSGS = [
ChatMessage(
content=(
"あなたは、既存の回答を改良する際に2つのモードで厳密に動作するQAシステムのエキスパートです。\n"
"1. 新しいコンテキストを使用して元の回答を**書き直す**。\n"
"2. 新しいコンテキストが役に立たない場合は、元の回答を**繰り返す**。\n"
"回答内で元の回答やコンテキストを直接参照しないでください。\n"
"疑問がある場合は、元の答えを繰り返してください。"
"New Context: {context_msg}\n"
"Query: {query_str}\n"
"Original Answer: {existing_answer}\n"
"New Answer: "
),
role=MessageRole.USER,
)
]
# チャットRefineプロンプト
CHAT_REFINE_PROMPT = ChatPromptTemplate(message_templates=CHAT_REFINE_PROMPT_TMPL_MSGS)
def setChatEngine():
callback_manager = CallbackManager([st.session_state.llama_debug_handler])
service_context = ServiceContext.from_defaults(llm=llm,node_parser=node_parser,callback_manager=callback_manager)
response_synthesizer = get_response_synthesizer(
response_mode='refine',
text_qa_template= CHAT_TEXT_QA_PROMPT,
refine_template=CHAT_REFINE_PROMPT,
)
st.session_state.query_engine = st.session_state.index.as_query_engine(
response_synthesizer=response_synthesizer,
service_context=service_context,
)
st.session_state.chat_engine = CondenseQuestionChatEngine.from_defaults(
query_engine=st.session_state.query_engine,
condense_question_prompt=custom_prompt,
verbose=True
)