File size: 8,989 Bytes
9d613db
f96612a
9d613db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f96612a
 
 
 
 
9d613db
 
 
 
 
 
 
 
 
 
09018bc
4f4b4ad
9d613db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f4b4ad
 
9d613db
 
 
 
 
 
 
 
 
60ab24b
 
 
 
 
 
 
 
 
 
 
 
 
9d613db
fe99578
6831b20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d613db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6831b20
 
9d613db
6831b20
9d613db
 
 
 
 
 
 
 
 
 
 
572c0de
6831b20
 
 
 
 
 
 
 
 
 
042e63c
44f4dbd
6831b20
 
 
 
9d613db
2a536a8
 
 
 
 
042e63c
6831b20
2a536a8
 
6831b20
2a536a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0962da9
 
2a536a8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import streamlit as st
import sys
from langchain.chains import RetrievalQA
from langchain.document_loaders import WebBaseLoader
from langchain.chains.question_answering import load_qa_chain
from langchain import PromptTemplate, LLMChain
from langchain import HuggingFaceHub
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import TextLoader
from sentence_transformers.util import semantic_search
import requests
from pathlib import Path
from time import sleep
import torch
import os
import random
import string
from dotenv import load_dotenv
load_dotenv()

#from langchain.prompts.chat import (ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate)
from langchain.text_splitter import RecursiveCharacterTextSplitter
import timeit
import datetime

# Page chrome: browser-tab title, wide layout, and the greeting at the top.
st.set_page_config(
    layout="wide",
    page_title="USinoIP Website AI Chat Assistant",
)
st.subheader("Welcome to USinoIP Website AI Chat Assistant.")

# Inline the project stylesheet into the page. The file is read with the
# platform default encoding and must sit in the working directory.
css_file = "main.css"
with open(css_file) as stylesheet:
    css_rules = stylesheet.read()
st.markdown(f"<style>{css_rules}</style>", unsafe_allow_html=True)

# Credentials and model identifiers are taken from the process environment.
# Any variable that is unset yields None here.
HUGGINGFACEHUB_API_TOKEN = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
model_id = os.environ.get('model_id')
hf_token = os.environ.get('hf_token')
# NOTE: the LLM repository id is read from "LLM_RepoID", not from "repo_id".
repo_id = os.environ.get('LLM_RepoID')

# Endpoint and auth header used by the embedding requests. An unset
# model_id / hf_token is interpolated as the literal text "None".
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
headers = {"Authorization": f"Bearer {hf_token}"}

def get_embeddings(input_str_texts):
    """Fetch embeddings for *input_str_texts* from the configured endpoint.

    POSTs the input to ``api_url`` with the module-level ``headers`` and asks
    the service to wait until the model is loaded. In this file the function
    is called with a list of text chunks and with a single question string.

    Returns:
        The decoded JSON body of the response, whatever it contains. The HTTP
        status is not checked and no request timeout is set.
    """
    payload = {
        "inputs": input_str_texts,
        "options": {"wait_for_model": True},
    }
    reply = requests.post(api_url, headers=headers, json=payload)
    return reply.json()

# Hosted LLM used to write the final answer; generation settings are passed
# through to the model unchanged.
llm = HuggingFaceHub(
    repo_id=repo_id,
    model_kwargs={
        "min_length": 100,
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 0.1,
        "top_k": 50,
        "top_p": 0.95,
        "eos_token_id": 49155,
    },
)

# Prompt for the "stuff" QA chain: {context} receives the retrieved text and
# {question} the user's query.
# NOTE(review): "Repsonse" is misspelled in the template. It is left as-is
# here because editing the prompt changes what is sent to the model.
prompt_template = """
You are a very helpful AI assistant. Please ONLY use {context} to answer the user's question. If you don't know the answer, just say that you don't know. DON'T try to make up an answer.
Your response should be full and detailed.
Question: {question}
Helpful AI Repsonse:
"""
PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)
chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT)

def generate_random_string(length):
    """Return a string of *length* random lowercase ASCII letters.

    Characters are drawn one at a time with ``random.choice``; a *length*
    of zero (or less) yields the empty string.
    """
    alphabet = string.ascii_lowercase
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)

print(f"定义处理多余的Context文本的函数")
def remove_context(text):
    """Strip a leading echoed-context paragraph from a model response.

    If *text* contains the marker ``'Context:'``, everything up to and
    including the first blank line (``'\\n\\n'``) is dropped and the remainder
    is returned. Note that the cut starts at the beginning of *text*, not at
    the marker itself.

    Args:
        text: The raw response string.

    Returns:
        The response without its first paragraph when the marker is present
        and a blank line exists; otherwise *text* unchanged.
    """
    # No marker: nothing to remove.
    if 'Context:' not in text:
        return text
    # Locate the first blank line that ends the leading paragraph.
    end_of_context = text.find('\n\n')
    if end_of_context == -1:
        # str.find returns -1 when there is no blank line. Slicing with
        # -1 + 2 would silently drop the first character, so keep the text.
        return text
    # '+ 2' skips over the two newline characters themselves.
    return text[end_of_context + 2:]
print(f"处理多余的Context文本函数定义结束")
    
# Site whose landing page is loaded and indexed for the chat.
url = "https://www.usinoip.com"
# A single-article URL was tried here earlier:
#   https://www.usinoip.com/UpdatesAbroad/290.html

# Give every session-state slot used by this script an empty-string default
# so later reads never hit a missing key. Existing values are left alone.
for _state_key in (
    "url_loader",
    "raw_text",
    "initial_page_content",
    "final_page_content",
    "texts",
    "initial_embeddings",
    "db_embeddings",
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""

# Plain module-level scratch values (not kept in session state); both are
# reassigned when a question is submitted.
i_file_path = ""
random_string = ""

# Image file displayed in the sidebar.
wechat_image = "WeChatCode.jpg"

# Style for the highlighted sidebar spans (class "blue-underline").
# NOTE(review): "bold" is not a text-decoration value, so presumably only the
# colour rule takes effect. Confirm whether underline or bold was intended.
_SIDEBAR_SPAN_CSS = """
    <style>
    .blue-underline {
        text-decoration: bold;
        color: blue;
    }
    </style>
    """
st.sidebar.markdown(_SIDEBAR_SPAN_CSS, unsafe_allow_html=True)

# Centre images placed in the sidebar and limit them to half its width.
_SIDEBAR_IMAGE_CSS = """
    <style>
        [data-testid=stSidebar] [data-testid=stImage]{
            text-align: center;
            display: block;
            margin-left: auto;
            margin-right: auto;
            width: 50%;
        }
    </style>
    """
st.markdown(_SIDEBAR_IMAGE_CSS, unsafe_allow_html=True)

# Free-text question box shown in the main area.
user_question = st.text_input("Enter your query here and AI-Chat with your website:")

# Splitter for the scraped page text: breaks on single newlines and targets
# 1000-character chunks with a 200-character overlap, measured with len().
text_splitter = CharacterTextSplitter(
    separator="\n",
    chunk_size=1000,
    chunk_overlap=200,
    length_function=len,
)

# Sidebar: static notices and contact details, followed by preparation of the
# searchable corpus (page text -> chunks -> embeddings) in session state.
with st.sidebar:
    st.subheader("You are chatting with USinoIP official website.")
    st.write("Note & Disclaimer: This app is provided on open source framework and is for information purpose only. NO guarantee is offered regarding information accuracy. NO liability could be claimed against whoever associated with this app in any manner. User should consult a qualified legal professional for legal advice.")
    # NOTE(review): the address below looks like an obfuscation placeholder
    # rather than a real mailbox - confirm the intended contact address.
    st.sidebar.markdown("Contact: [[email protected]](mailto:[email protected])")
    st.sidebar.markdown('WeChat: <span class="blue-underline">pat2win</span>, or scan the code below.', unsafe_allow_html=True)
    st.image(wechat_image)
    st.subheader("Enjoy Chatting!")
    st.sidebar.markdown('<span class="blue-underline">Life Enhancing with AI.</span>', unsafe_allow_html=True)
    # Streamlit re-executes the script on every interaction. Build the corpus
    # only while no embedding tensor is stored for this session, so the site
    # is not re-downloaded and re-embedded on each rerun. A failed attempt
    # leaves the slot without a tensor and is retried on the next rerun.
    if not isinstance(st.session_state.db_embeddings, torch.Tensor):
        try:
            with st.spinner("Preparing website materials for you..."):
                st.session_state.url_loader = WebBaseLoader([url])
                st.session_state.raw_text = st.session_state.url_loader.load()
                st.session_state.initial_page_content = st.session_state.raw_text[0].page_content
                st.session_state.final_page_content = str(st.session_state.initial_page_content)
                st.session_state.temp_texts = text_splitter.split_text(st.session_state.final_page_content)
                st.session_state.texts = st.session_state.temp_texts
                st.session_state.initial_embeddings = get_embeddings(st.session_state.texts)
                # Raises if the embedding service did not return a numeric
                # array; that case is reported by the handler below.
                st.session_state.db_embeddings = torch.FloatTensor(st.session_state.initial_embeddings)
                print("DB Embeddings Ready.")
        except Exception as e:
            # Best-effort: keep the page usable, but surface the failure
            # instead of swallowing it silently.
            print(f"Failed to prepare website materials: {e!r}")
            st.warning("Website materials could not be prepared. AI answers are unavailable until this succeeds.")
          
# Answer the question when the button is pressed: embed the question, pick the
# five closest text chunks, and let the QA chain answer from those chunks.
if st.button('Get AI Response'):
    # A question counts only if it has non-whitespace characters. This single
    # test is equivalent to the five overlapping checks it replaces.
    if user_question.strip():
        if not isinstance(st.session_state.db_embeddings, torch.Tensor):
            # Corpus embeddings were never built (preparation failed), so a
            # semantic search would crash. Tell the user instead.
            st.warning("Website materials are not ready yet, so the question cannot be answered. Please try again shortly.")
        else:
            with st.spinner("AI Thinking...Please wait a while to Cheers!"):
                q_embedding = get_embeddings(user_question)
                final_q_embedding = torch.FloatTensor(q_embedding)
                print("Question Embeddings Ready.")
                hits = semantic_search(final_q_embedding, st.session_state.db_embeddings, top_k=5)
                # hits[0] holds the matches for the single query; each match
                # carries the index of its chunk as 'corpus_id'.
                page_contents = [st.session_state.texts[hit['corpus_id']] for hit in hits[0]]
                # str() of a list is its repr, in which newlines appear as the
                # two characters backslash + n; those are stripped here.
                temp_page_contents = str(page_contents)
                final_page_contents = temp_page_contents.replace('\\n', '')
                # The chain takes documents, so the context goes through a
                # scratch text file with a random name.
                random_string = generate_random_string(20)
                i_file_path = random_string + ".txt"
                try:
                    with open(i_file_path, "w", encoding="utf-8") as file:
                        file.write(final_page_contents)
                    text_loader = TextLoader(i_file_path, encoding="utf-8")
                    loaded_documents = text_loader.load()
                finally:
                    # Remove the scratch file so queries do not pile up
                    # stray .txt files in the working directory.
                    if os.path.exists(i_file_path):
                        os.remove(i_file_path)
                temp_ai_response = chain({"input_documents": loaded_documents, "question": user_question}, return_only_outputs=False)
                initial_ai_response = temp_ai_response['output_text']
                cleaned_initial_ai_response = remove_context(initial_ai_response)
                # Keep only the text before the first follow-up turn, collapse
                # blank lines, and strip any remaining chat-template markers.
                final_ai_response = cleaned_initial_ai_response.split('<|end|>\n<|system|>\n<|end|>\n<|user|>')[0].strip().replace('\n\n', '\n').replace('<|end|>', '').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '')
                print("AI Response:")
                print(final_ai_response)
                st.write("AI Response:")
                st.write(final_ai_response)