import streamlit as st
import requests
import torch
import os
import random
import string
from dotenv import load_dotenv
from langchain import HuggingFaceHub, PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders import WebBaseLoader, TextLoader
from langchain.text_splitter import CharacterTextSplitter
from sentence_transformers.util import semantic_search

load_dotenv()

st.set_page_config(page_title="USinoIP Website AI Chat Assistant", layout="wide")
st.subheader("Welcome to USinoIP Website AI Chat Assistant.")

# Inject the local stylesheet.
css_file = "main.css"
with open(css_file) as f:
    st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)

# Credentials and model settings come from the environment (.env).
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
model_id = os.getenv('model_id')
hf_token = os.getenv('hf_token')
repo_id = os.getenv('LLM_RepoID')

api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
headers = {"Authorization": f"Bearer {hf_token}"}

def get_embeddings(input_str_texts):
    # Embed text via the Hugging Face Inference API feature-extraction pipeline.
    response = requests.post(
        api_url,
        headers=headers,
        json={"inputs": input_str_texts, "options": {"wait_for_model": True}},
    )
    return response.json()

llm = HuggingFaceHub(
    repo_id=repo_id,
    model_kwargs={
        "min_length": 100,
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 0.1,
        "top_k": 50,
        "top_p": 0.95,
        "eos_token_id": 49155,
    },
)

prompt_template = """
You are a very helpful AI assistant. Please ONLY use {context} to answer the user's question.
If you don't know the answer, just say that you don't know. DON'T try to make up an answer.
Your response should be full and detailed.
Question: {question}
Helpful AI Response: """

PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT)
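
# Optional guard (a sketch, not part of the original app): on failure the
# Inference API returns a JSON object such as {"error": "..."} instead of a
# list of vectors, and torch.FloatTensor() would then fail with a confusing
# message downstream. A thin wrapper can surface the API error directly.
def get_embeddings_checked(input_str_texts):
    result = get_embeddings(input_str_texts)
    if isinstance(result, dict) and "error" in result:
        raise RuntimeError(f"Embedding API error: {result['error']}")
    return result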

def generate_random_string(length):
    # Random lowercase string, used to name a unique scratch file per query.
    letters = string.ascii_lowercase
    return ''.join(random.choice(letters) for i in range(length))

print("Defining the helper that strips redundant 'Context:' text")
def remove_context(text):
    # Check whether 'Context:' is present.
    if 'Context:' in text:
        # Find the first '\n\n' and drop everything from 'Context:' up to it.
        end_of_context = text.find('\n\n')
        return text[end_of_context + 2:]  # '+2' skips the two newline characters
    else:
        # If 'Context:' is not present, return the original text.
        return text
print("Finished defining the 'Context:' cleanup helper")
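
# Quick illustration of remove_context (hypothetical inputs, safe to delete):
# a leaked "Context: ..." preamble ends at the first blank line, so everything
# before that line is dropped; text without the marker passes through unchanged.
assert remove_context("Context: some passage...\n\nAnswer: 42.") == "Answer: 42."
assert remove_context("Answer: 42.") == "Answer: 42."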

url = "https://www.usinoip.com"
#url = "https://www.usinoip.com/UpdatesAbroad/290.html"

# Initialize session state so Streamlit reruns reuse the scraped page
# and its embeddings instead of recomputing them.
if "url_loader" not in st.session_state:
    st.session_state.url_loader = ""
if "raw_text" not in st.session_state:
    st.session_state.raw_text = ""
if "initial_page_content" not in st.session_state:
    st.session_state.initial_page_content = ""
if "final_page_content" not in st.session_state:
    st.session_state.final_page_content = ""
if "texts" not in st.session_state:
    st.session_state.texts = ""
if "initial_embeddings" not in st.session_state:
    st.session_state.initial_embeddings = ""
if "db_embeddings" not in st.session_state:
    st.session_state.db_embeddings = ""

i_file_path = ""
random_string = ""
wechat_image = "WeChatCode.jpg"

st.sidebar.markdown(
    """
    """,
    unsafe_allow_html=True
)
st.markdown(
    """
    """,
    unsafe_allow_html=True
)

user_question = st.text_input("Enter your query here and AI-Chat with your website:")

text_splitter = CharacterTextSplitter(
    separator="\n",
    chunk_size=1000,
    chunk_overlap=200,
    length_function=len,
)

with st.sidebar:
    st.subheader("You are chatting with USinoIP official website.")
    st.write("Note & Disclaimer: This app is provided on an open-source framework and is for information purposes only. NO guarantee is offered regarding information accuracy. NO liability could be claimed against anyone associated with this app in any manner. Users should consult a qualified legal professional for legal advice.")
    st.sidebar.markdown("Contact: [aichat101@foxmail.com](mailto:aichat101@foxmail.com)")
    st.sidebar.markdown('WeChat: pat2win, or scan the code below.', unsafe_allow_html=True)
    st.image(wechat_image)
    st.subheader("Enjoy Chatting!")
    st.sidebar.markdown('Life Enhancing with AI.', unsafe_allow_html=True)

try:
    with st.spinner("Preparing website materials for you..."):
        st.session_state.url_loader = WebBaseLoader([url])
        st.session_state.raw_text = st.session_state.url_loader.load()
        st.session_state.initial_page_content = st.session_state.raw_text[0].page_content
        st.session_state.final_page_content = str(st.session_state.initial_page_content)
        # May warn: "Created a chunk of size 3431, which is longer than the specified 1000"
        st.session_state.temp_texts = text_splitter.split_text(st.session_state.final_page_content)
        st.session_state.texts = st.session_state.temp_texts
        st.session_state.initial_embeddings = get_embeddings(st.session_state.texts)
        st.session_state.db_embeddings = torch.FloatTensor(st.session_state.initial_embeddings)
        print("DB Embeddings Ready.")
except Exception as e:
    # st.write("Unknown error.")
    # print("Please enter a valid URL.")
    # st.stop()
    pass

if st.button('Get AI Response'):
    # Proceed only when the user typed a non-blank question.
    if user_question.strip() != "":
        with st.spinner("AI Thinking...Please wait a while to Cheers!"):
            q_embedding = get_embeddings(user_question)
            final_q_embedding = torch.FloatTensor(q_embedding)
            print("Question Embeddings Ready.")
            # Retrieve the five chunks most similar to the question.
            hits = semantic_search(final_q_embedding, st.session_state.db_embeddings, top_k=5)
            page_contents = []
            for hit in hits[0]:
                page_contents.append(st.session_state.texts[hit['corpus_id']])
            temp_page_contents = str(page_contents)
            final_page_contents = temp_page_contents.replace('\\n', '')
            # Write the retrieved chunks to a uniquely named scratch file so
            # TextLoader can hand them to the QA chain as documents.
            random_string = generate_random_string(20)
            i_file_path = random_string + ".txt"
            with open(i_file_path, "w", encoding="utf-8") as file:
                file.write(final_page_contents)
            text_loader = TextLoader(i_file_path, encoding="utf-8")
            loaded_documents = text_loader.load()
            temp_ai_response = chain({"input_documents": loaded_documents, "question": user_question}, return_only_outputs=False)
            initial_ai_response = temp_ai_response['output_text']
            cleaned_initial_ai_response = remove_context(initial_ai_response)
            # Strip chat-template special tokens from the model output.
            final_ai_response = (
                cleaned_initial_ai_response.split('<|end|>\n<|system|>\n<|end|>\n<|user|>')[0]
                .strip()
                .replace('\n\n', '\n')
                .replace('<|end|>', '')
                .replace('<|user|>', '')
                .replace('<|system|>', '')
                .replace('<|assistant|>', '')
            )
            print("AI Response:")
            print(final_ai_response)
            st.write("AI Response:")
            st.write(final_ai_response)
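
# Housekeeping sketch (an assumption, not in the original): each query leaves a
# 20-character .txt scratch file behind. Once TextLoader has loaded the
# documents into memory, the file could be deleted right after loading, e.g.:
#
#     loaded_documents = text_loader.load()
#     os.remove(i_file_path)  # the retrieved chunks are already in memory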