qatool / common.py
naotakigawa's picture
Upload 6 files
8b16906
raw
history blame
6.86 kB
import streamlit as st
import os
import pickle
import ipaddress
import tiktoken
from pathlib import Path
from streamlit import runtime
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit.web.server.websocket_headers import _get_websocket_headers
from llama_index import SimpleDirectoryReader
from llama_index import Prompt
from llama_index.chat_engine import CondenseQuestionChatEngine;
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index import ServiceContext, SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
from llama_index.constants import DEFAULT_CHUNK_OVERLAP
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.callbacks import CallbackManager
from llama_index.llms import OpenAI
from log import logger
# 接続元制御
ALLOW_IP_ADDRESS = os.environ["ALLOW_IP_ADDRESS"]
# Azure AD app registration details
CLIENT_ID = os.environ["CLIENT_ID"]
CLIENT_SECRET = os.environ["CLIENT_SECRET"]
TENANT_ID = os.environ["TENANT_ID"]
# Azure API
REDIRECT_URI = os.environ["REDIRECT_URI"]
AUTHORITY = f"https://login.microsoftonline.com/{TENANT_ID}"
SCOPES = ["openid", "profile", "User.Read"]
# 接続元IP取得
def get_remote_ip():
ctx = get_script_run_ctx()
session_info = runtime.get_instance().get_client(ctx.session_id)
headers = _get_websocket_headers()
return session_info.request.remote_ip, headers.get("X-Forwarded-For")
# 接続元IP許可判定
def is_allow_ip_address():
remote_ip, x_forwarded_for = get_remote_ip()
logger.info("remote_ip:"+remote_ip)
if x_forwarded_for is not None:
remote_ip = x_forwarded_for
# localhost
if remote_ip == "::1":
return True
# プライベートIP
ipaddr = ipaddress.IPv4Address(remote_ip)
logger.info("ipaddr:"+str(ipaddr))
if ipaddr.is_private:
return True
# その他(許可リスト判定)
return remote_ip in ALLOW_IP_ADDRESS
#ログインの確認
def check_login():
if not is_allow_ip_address():
st.title("HTTP 403 Forbidden")
st.stop()
if "login_token" not in st.session_state or not st.session_state.login_token:
st.warning("**ログインしてください**")
st.stop()
INDEX_NAME = os.environ["INDEX_NAME"]
PKL_NAME = os.environ["PKL_NAME"]
# デバッグ用
llm = OpenAI(model='gpt-3.5-turbo', temperature=0.8, max_tokens=256)
text_splitter = TokenTextSplitter(separator="。", chunk_size=1500
, chunk_overlap=DEFAULT_CHUNK_OVERLAP
, tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode)
node_parser = SimpleNodeParser(text_splitter=text_splitter)
custom_prompt = Prompt("""\
以下はこれまでの会話履歴と、ドキュメントを検索して回答する必要がある、ユーザーからの会話文です。
会話と新しい会話文に基づいて、検索クエリを作成します。
挨拶された場合、挨拶を返してください。
答えを知らない場合は、「わかりません」と回答してください。
全ての回答は日本語で行ってください。
会話履歴:
{chat_history}
新しい会話文:
{question}
Search query:
""")
chat_history = []
def setChatEngine():
callback_manager = CallbackManager([st.session_state.llama_debug_handler])
service_context = ServiceContext.from_defaults(llm=llm,node_parser=node_parser,callback_manager=callback_manager)
response_synthesizer = get_response_synthesizer(response_mode='refine')
st.session_state.query_engine = st.session_state.index.as_query_engine(
response_synthesizer=response_synthesizer,
service_context=service_context,
)
st.session_state.chat_engine = CondenseQuestionChatEngine.from_defaults(
query_engine=st.session_state.query_engine,
condense_question_prompt=custom_prompt,
chat_history=chat_history,
verbose=True
)
# chat mode reacの記述
# from langchain.prompts.chat import (
# ChatPromptTemplate,
# HumanMessagePromptTemplate,
# SystemMessagePromptTemplate,
# )
# from llama_index.prompts.prompts import RefinePrompt, QuestionAnswerPrompt
# from llama_index.prompts import Prompt
# chat_text_qa_msgs = [
# SystemMessagePromptTemplate.from_template(
# "文脈が役に立たない場合でも、必ず日本語で質問に答えてください。"
# ),
# HumanMessagePromptTemplate.from_template(
# "以下に、コンテキスト情報を提供します。 \n"
# "---------------------\n"
# "{context_str}"
# "\n---------------------\n"
# "回答には以下を含めてください。\n"
# "・最初に問い合わせへのお礼してください\n"
# "・回答には出典のドキュメント名を含めるようにしてください。\n"
# "・質問内容を要約してください\n"
# "・最後に不明な点がないか確認してください \n"
# "この情報を踏まえて、次の質問に回答してください: {query_str}\n"
# "答えを知らない場合は、「わからない」と回答してください。また、必ず日本語で回答してください。"
# ),
# ]
# REFINE_PROMPT = ("元の質問は次のとおりです: {query_str} \n"
# "既存の回答を提供しました: {existing_answer} \n"
# "既存の答えを洗練する機会があります \n"
# "(必要な場合のみ)以下にコンテキストを追加します。 \n"
# "------------\n"
# "{context_msg}\n"
# "------------\n"
# "新しいコンテキストを考慮して、元の答えをより良く洗練して質問に答えてください。\n"
# "回答には出典のドキュメント名を含めるようにしてください。\n"
# "コンテキストが役に立たない場合は、元の回答と同じものを返します。"
# "どのような場合でも、返答は日本語で行います。")
# refine_prompt = RefinePrompt(REFINE_PROMPT)
# def setChatEngine():
# callback_manager = CallbackManager([st.session_state.llama_debug_handler])
# service_context = ServiceContext.from_defaults(node_parser=node_parser,callback_manager=callback_manager)
# response_synthesizer = get_response_synthesizer(response_mode='refine')
# st.session_state.chat_engine = st.session_state.index.as_chat_engine(
# response_synthesizer=response_synthesizer,
# service_context=service_context,
# chat_mode="condense_question",
# text_qa_template= Prompt.from_langchain_prompt(ChatPromptTemplate.from_messages(chat_text_qa_msgs)),
# refine_template=refine_prompt,
# verbose=True
# )