USinoSiteAIChat / appOKed.py
binqiangliu's picture
Rename app.py to appOKed.py
0962da9
import streamlit as st
import sys
from langchain.chains import RetrievalQA
from langchain.document_loaders import WebBaseLoader
from langchain.chains.question_answering import load_qa_chain
from langchain import PromptTemplate, LLMChain
from langchain import HuggingFaceHub
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import TextLoader
from sentence_transformers.util import semantic_search
import requests
from pathlib import Path
from time import sleep
import torch
import os
import random
import string
from dotenv import load_dotenv
load_dotenv()
#from langchain.prompts.chat import (ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate)
from langchain.text_splitter import RecursiveCharacterTextSplitter
import timeit
import datetime
st.set_page_config(page_title="USinoIP Website AI Chat Assistant", layout="wide")
st.subheader("Welcome to USinoIP Website AI Chat Assistant.")

# Inject the app's custom stylesheet into the page. The file is decoded as UTF-8
# explicitly so the result does not depend on the platform's default locale encoding.
css_file = "main.css"
with open(css_file, encoding="utf-8") as f:
    st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)
# Credentials and model identifiers, all taken from the process environment
# (which load_dotenv() at import time may have populated from a .env file).
# Missing variables come back as None.
HUGGINGFACEHUB_API_TOKEN = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
model_id = os.environ.get('model_id')
hf_token = os.environ.get('hf_token')
repo_id = os.environ.get('LLM_RepoID')

# HuggingFace Inference API endpoint of the feature-extraction (embedding) model,
# plus the bearer-token header used to call it.
api_url = "https://api-inference.huggingface.co/pipeline/feature-extraction/" + str(model_id)
headers = {"Authorization": "Bearer " + str(hf_token)}
def get_embeddings(input_str_texts, timeout=120):
    """Embed text(s) with the HuggingFace Inference API feature-extraction pipeline.

    Args:
        input_str_texts: A single string or a list of strings to embed.
        timeout: Seconds to wait on the HTTP connection/read before giving up.
            The request asks the service to wait for the model to load, so the
            default is generous; pass None to wait indefinitely.

    Returns:
        The decoded JSON response body (presumably one embedding vector per
        input text -- confirm against the configured model).

    Raises:
        requests.HTTPError: If the service answers with a 4xx/5xx status. The
            service's JSON error object is no longer returned as if it were
            embeddings.
        requests.Timeout: If the service does not respond within `timeout`.
    """
    response = requests.post(
        api_url,
        headers=headers,
        json={"inputs": input_str_texts, "options": {"wait_for_model": True}},
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()
# Hosted HuggingFace text-generation model that writes the answers; repo_id comes
# from the LLM_RepoID environment variable. Sampling is enabled at a low temperature.
# NOTE(review): eos_token_id 49155 is presumably the end-of-turn token of the chosen
# model -- confirm against that model's tokenizer.
llm = HuggingFaceHub(
    repo_id=repo_id,
    model_kwargs=dict(
        min_length=100,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.1,
        top_k=50,
        top_p=0.95,
        eos_token_id=49155,
    ),
)
# Prompt for the question-answering chain. The retrieved website text is
# substituted for {context} and the user's query for {question}.
prompt_template = """
You are a very helpful AI assistant. Please ONLY use {context} to answer the user's question. If you don't know the answer, just say that you don't know. DON'T try to make up an answer.
Your response should be full and detailed.
Question: {question}
Helpful AI Response:
"""
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
# LangChain's "stuff" chain inserts all input documents into this one prompt.
chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT)
def generate_random_string(length):
    """Return `length` random lowercase ASCII letters (an empty string if length <= 0).

    Uses the module-level `random` generator, so the result is reproducible
    after `random.seed(...)`.
    """
    alphabet = string.ascii_lowercase
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return "".join(picked)
# Progress log: "defining the function that strips extra Context text".
print(f"定义处理多余的Context文本的函数")


def remove_context(text):
    """Remove an echoed 'Context:' section from a model response.

    The section runs from the first 'Context:' marker up to and including the
    next blank line (two consecutive newlines). Text before and after that
    section is kept.

    Args:
        text: Raw response text.

    Returns:
        The text with that section removed. `text` is returned unchanged when it
        has no 'Context:' marker, or when no blank line follows the marker (the
        end of the section cannot be located).
    """
    start = text.find('Context:')
    if start == -1:
        return text
    # Look for the terminator only *after* the marker, so a blank line earlier in
    # the text is not mistaken for the end of the context section.
    end = text.find('\n\n', start)
    if end == -1:
        return text
    return text[:start] + text[end + 2:]  # +2 skips the two newline characters


# Progress log: "finished defining the Context-stripping function".
print(f"处理多余的Context文本函数定义结束")
# Website whose content backs the chat assistant.
url = "https://www.usinoip.com"

# Seed per-session state with empty placeholders. Streamlit re-runs this script on
# every interaction, so only keys that do not exist yet are initialised.
for _state_key in (
    "url_loader",
    "raw_text",
    "initial_page_content",
    "final_page_content",
    "texts",
    "initial_embeddings",
    "db_embeddings",
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""

# Module-level placeholders for the per-question temporary context file name.
i_file_path = ""
random_string = ""
# QR-code image shown in the sidebar contact section.
wechat_image = "WeChatCode.jpg"

# Highlight style for sidebar snippets (WeChat ID, slogan). CSS `text-decoration`
# has no `bold` value, so browsers ignored that declaration. Bold is expressed via
# `font-weight`, and the underline the class name promises via `text-decoration`.
st.sidebar.markdown(
    """
<style>
.blue-underline {
font-weight: bold;
text-decoration: underline;
color: blue;
}
</style>
""",
    unsafe_allow_html=True,
)

# Center the sidebar image (the WeChat QR code) at half the sidebar width.
st.markdown(
    """
<style>
[data-testid=stSidebar] [data-testid=stImage]{
text-align: center;
display: block;
margin-left: auto;
margin-right: auto;
width: 50%;
}
</style>
""",
    unsafe_allow_html=True,
)
user_question = st.text_input("Enter your query here and AI-Chat with your website:")

# Split the website text into ~1000-character chunks (200 overlap) for embedding.
# The recursive splitter falls back from paragraphs to lines, then words, then
# characters, so chunks respect chunk_size even when the page contains very long
# lines. The plain "\n" CharacterTextSplitter produced oversized chunks (one of
# 3431 characters was observed), whose tails would presumably be truncated by the
# embedding model.
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", " ", ""],
    chunk_size=1000,
    chunk_overlap=200,
    length_function=len,
)
# Sidebar: scope notice, disclaimer, contact details and WeChat QR code.
# NOTE(review): the contact link text/target below read "[email protected]", which
# looks like an e-mail-obfuscation placeholder left by copying from a web page rather
# than a real address -- confirm and restore the intended e-mail.
st.sidebar.subheader("You are chatting with USinoIP official website.")
st.sidebar.write("Note & Disclaimer: This app is provided on open source framework and is for information purpose only. NO guarantee is offered regarding information accuracy. NO liability could be claimed against whoever associated with this app in any manner. User should consult a qualified legal professional for legal advice.")
st.sidebar.markdown("Contact: [[email protected]](mailto:[email protected])")
st.sidebar.markdown('WeChat: <span class="blue-underline">pat2win</span>, or scan the code below.', unsafe_allow_html=True)
st.sidebar.image(wechat_image)
st.sidebar.subheader("Enjoy Chatting!")
st.sidebar.markdown('<span class="blue-underline">Life Enhancing with AI.</span>', unsafe_allow_html=True)
# Build the retrieval corpus (page text chunks plus their embeddings) once per
# session. Streamlit re-executes this script on every widget interaction; without
# this guard the website was re-downloaded and re-embedded on every keystroke and
# button click.
if not isinstance(st.session_state.db_embeddings, torch.Tensor):
    try:
        with st.spinner("Preparing website materials for you..."):
            st.session_state.url_loader = WebBaseLoader([url])
            st.session_state.raw_text = st.session_state.url_loader.load()
            # Only the first loaded document is used.
            st.session_state.initial_page_content = st.session_state.raw_text[0].page_content
            st.session_state.final_page_content = str(st.session_state.initial_page_content)
            st.session_state.temp_texts = text_splitter.split_text(st.session_state.final_page_content)
            st.session_state.texts = st.session_state.temp_texts
            st.session_state.initial_embeddings = get_embeddings(st.session_state.texts)
            # Fails if the service returned anything other than a numeric matrix
            # (e.g. a JSON error object); that failure is handled below.
            st.session_state.db_embeddings = torch.FloatTensor(st.session_state.initial_embeddings)
            print("DB Embeddings Ready.")
    except Exception as e:
        # Best effort: keep the page usable, but report the failure instead of
        # silently swallowing it. Loading is retried on the next rerun because
        # db_embeddings is still not a tensor.
        print(f"Failed to prepare website materials: {e!r}")
        st.warning("Website materials could not be loaded right now; please try again later.")
def _build_context(chunks):
    """Join retrieved text chunks into one single-line context string for the prompt.

    Newlines inside chunks become spaces rather than being deleted, so words on
    adjacent lines are not glued together.
    """
    return " ".join(chunk.replace("\n", " ") for chunk in chunks)


def _clean_ai_response(raw_text):
    """Strip an echoed 'Context:' section and chat-template tokens from model output."""
    text = remove_context(raw_text)
    # Keep only what precedes the first end/system/user turn boundary; presumably
    # the model continues into further conversation turns after it.
    text = text.split('<|end|>\n<|system|>\n<|end|>\n<|user|>')[0].strip()
    text = text.replace('\n\n', '\n')
    for token in ('<|end|>', '<|user|>', '<|system|>', '<|assistant|>'):
        text = text.replace(token, '')
    return text


if st.button('Get AI Response'):
    # Ignore empty or whitespace-only questions.
    if user_question.strip():
        if not isinstance(st.session_state.db_embeddings, torch.Tensor):
            # Corpus preparation failed or has not run; semantic_search would crash.
            st.warning("Website materials are not ready yet; please try again shortly.")
        else:
            with st.spinner("AI Thinking...Please wait a while to Cheers!"):
                q_embedding = get_embeddings(user_question)
                final_q_embedding = torch.FloatTensor(q_embedding)
                print("Question Embeddings Ready.")
                # Up to five most similar website chunks for the single query.
                hits = semantic_search(final_q_embedding, st.session_state.db_embeddings, top_k=5)
                page_contents = [st.session_state.texts[hit['corpus_id']] for hit in hits[0]]
                final_page_contents = _build_context(page_contents)
                # The QA chain consumes LangChain documents; obtain them by loading
                # the context from a uniquely named temporary text file.
                random_string = generate_random_string(20)
                i_file_path = random_string + ".txt"
                try:
                    with open(i_file_path, "w", encoding="utf-8") as file:
                        file.write(final_page_contents)
                    loaded_documents = TextLoader(i_file_path, encoding="utf-8").load()
                finally:
                    # The file is only needed until it is loaded; previously one
                    # file per question accumulated in the working directory.
                    Path(i_file_path).unlink(missing_ok=True)
                temp_ai_response = chain(
                    {"input_documents": loaded_documents, "question": user_question},
                    return_only_outputs=False,
                )
                final_ai_response = _clean_ai_response(temp_ai_response['output_text'])
                print("AI Response:")
                print(final_ai_response)
                st.write("AI Response:")
                st.write(final_ai_response)