binqiangliu commited on
Commit
b622bb6
·
1 Parent(s): 0962da9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -0
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import sys
3
+ from langchain.chains import RetrievalQA
4
+ from langchain.document_loaders import WebBaseLoader
5
+ from langchain.chains.question_answering import load_qa_chain
6
+ from langchain import PromptTemplate, LLMChain
7
+ from langchain import HuggingFaceHub
8
+ from PyPDF2 import PdfReader
9
+ from langchain.text_splitter import CharacterTextSplitter
10
+ from langchain.document_loaders import TextLoader
11
+ from sentence_transformers.util import semantic_search
12
+ import requests
13
+ from pathlib import Path
14
+ from time import sleep
15
+ import torch
16
+ import os
17
+ import random
18
+ import string
19
+ from dotenv import load_dotenv
20
+ load_dotenv()
21
+
22
+ #from langchain.prompts.chat import (ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate)
23
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
24
+ import timeit
25
+ import datetime
26
+
27
+ st.set_page_config(page_title="USinoIP Website AI Chat Assistant", layout="wide")
28
+ st.subheader("Welcome to USinoIP Website AI Chat Assistant.")
29
+
30
+ css_file = "main.css"
31
+ with open(css_file) as f:
32
+ st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)
33
+
34
+ HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
35
+ model_id = os.getenv('model_id')
36
+ hf_token = os.getenv('hf_token')
37
+ repo_id = os.getenv('LLM_RepoID')
38
+
39
+ #HUGGINGFACEHUB_API_TOKEN = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
40
+ #model_id = os.environ.get('model_id')
41
+ #hf_token = os.environ.get('hf_token')
42
+ #repo_id = os.environ.get('repo_id')
43
+
44
+ api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
45
+ headers = {"Authorization": f"Bearer {hf_token}"}
46
+
47
+ def get_embeddings(input_str_texts):
48
+ response = requests.post(api_url, headers=headers, json={"inputs": input_str_texts, "options":{"wait_for_model":True}})
49
+ return response.json()
50
+
51
+ llm = HuggingFaceHub(repo_id=repo_id,
52
+ model_kwargs={"min_length":100,
53
+ "max_new_tokens":1024, "do_sample":True,
54
+ "temperature":0.1,
55
+ "top_k":50,
56
+ "top_p":0.95, "eos_token_id":49155})
57
+
58
+ prompt_template = """
59
+ You are a very helpful AI assistant. Please ONLY use {context} to answer the user's question. If you don't know the answer, just say that you don't know. DON'T try to make up an answer.
60
+ Your response should be full and detailed.
61
+ Question: {question}
62
+ Helpful AI Repsonse:
63
+ """
64
+ PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
65
+ chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT)
66
+
67
+ def generate_random_string(length):
68
+ letters = string.ascii_lowercase
69
+ return ''.join(random.choice(letters) for i in range(length))
70
+
71
+ print(f"定义处理多余的Context文本的函数")
72
+ def remove_context(text):
73
+ # 检查 'Context:' 是否存在
74
+ if 'Context:' in text:
75
+ # 找到第一个 '\n\n' 的位置
76
+ end_of_context = text.find('\n\n')
77
+ # 删除 'Context:' 到第一个 '\n\n' 之间的部分
78
+ return text[end_of_context + 2:] # '+2' 是为了跳过两个换行符
79
+ else:
80
+ # 如果 'Context:' 不存在,返回原始文本
81
+ return text
82
+ print(f"处理多余的Context文本函数定义结束")
83
+
84
+ url="https://www.usinoip.com"
85
+ #url="https://www.usinoip.com/UpdatesAbroad/290.html"
86
+
87
+ if "url_loader" not in st.session_state:
88
+ st.session_state.url_loader = ""
89
+
90
+ if "raw_text" not in st.session_state:
91
+ st.session_state.raw_text = ""
92
+
93
+ if "initial_page_content" not in st.session_state:
94
+ st.session_state.initial_page_content = ""
95
+
96
+ if "final_page_content" not in st.session_state:
97
+ st.session_state.final_page_content = ""
98
+
99
+ if "texts" not in st.session_state:
100
+ st.session_state.texts = ""
101
+
102
+ #if "user_question" not in st.session_state:
103
+ # st.session_state.user_question = ""
104
+
105
+ if "initial_embeddings" not in st.session_state:
106
+ st.session_state.initial_embeddings = ""
107
+
108
+ if "db_embeddings" not in st.session_state:
109
+ st.session_state.db_embeddings = ""
110
+
111
+ #if "i_file_path" not in st.session_state:
112
+ # st.session_state.i_file_path = ""
113
+ i_file_path = ""
114
+
115
+ #if "file_path" not in st.session_state:
116
+ # st.session_state.file_path = ""
117
+
118
+ #if "random_string" not in st.session_state:
119
+ # st.session_state.random_string = ""
120
+ random_string = ""
121
+
122
+ wechat_image= "WeChatCode.jpg"
123
+
124
+ st.sidebar.markdown(
125
+ """
126
+ <style>
127
+ .blue-underline {
128
+ text-decoration: bold;
129
+ color: blue;
130
+ }
131
+ </style>
132
+ """,
133
+ unsafe_allow_html=True
134
+ )
135
+
136
+ st.markdown(
137
+ """
138
+ <style>
139
+ [data-testid=stSidebar] [data-testid=stImage]{
140
+ text-align: center;
141
+ display: block;
142
+ margin-left: auto;
143
+ margin-right: auto;
144
+ width: 50%;
145
+ }
146
+ </style>
147
+ """, unsafe_allow_html=True
148
+ )
149
+
150
+ user_question = st.text_input("Enter your query here and AI-Chat with your website:")
151
+
152
+ text_splitter = CharacterTextSplitter(
153
+ separator = "\n",
154
+ chunk_size = 1000,
155
+ chunk_overlap = 200,
156
+ length_function = len,
157
+ )
158
+
159
+ with st.sidebar:
160
+ st.subheader("You are chatting with USinoIP official website.")
161
+ st.write("Note & Disclaimer: This app is provided on open source framework and is for information purpose only. NO guarantee is offered regarding information accuracy. NO liability could be claimed against whoever associated with this app in any manner. User should consult a qualified legal professional for legal advice.")
162
+ st.sidebar.markdown("Contact: [[email protected]](mailto:[email protected])")
163
+ st.sidebar.markdown('WeChat: <span class="blue-underline">pat2win</span>, or scan the code below.', unsafe_allow_html=True)
164
+ st.image(wechat_image)
165
+ st.subheader("Enjoy Chatting!")
166
+ st.sidebar.markdown('<span class="blue-underline">Life Enhancing with AI.</span>', unsafe_allow_html=True)
167
+ if st.button('Get AI Response'):
168
+ try:
169
+ with st.spinner("Preparing website materials for you..."):
170
+ st.session_state.url_loader = WebBaseLoader([url])
171
+ st.session_state.raw_text = st.session_state.url_loader.load()
172
+ st.session_state.initial_page_content = st.session_state.raw_text[0].page_content
173
+ st.session_state.final_page_content = str(st.session_state.initial_page_content)
174
+ st.session_state.temp_texts = text_splitter.split_text(st.session_state.final_page_content)
175
+ #Created a chunk of size 3431, which is longer than the specified 1000
176
+ st.session_state.texts = st.session_state.temp_texts
177
+ st.session_state.initial_embeddings=get_embeddings(st.session_state.texts)
178
+ st.session_state.db_embeddings = torch.FloatTensor(st.session_state.initial_embeddings)
179
+ print("DB Embeddings Ready.")
180
+ except Exception as e:
181
+ st.write("Unknow error.")
182
+ print("Please enter a valide URL.")
183
+ st.stop()
184
+
185
+ if user_question !="" and not user_question.strip().isspace() and not user_question == "" and not user_question.strip() == "" and not user_question.isspace():
186
+ with st.spinner("AI Thinking...Please wait a while to Cheers!"):
187
+ q_embedding=get_embeddings(user_question)
188
+ final_q_embedding = torch.FloatTensor(q_embedding)
189
+ print("Question Embeddings Ready.")
190
+ hits = semantic_search(final_q_embedding, st.session_state.db_embeddings, top_k=5)
191
+ page_contents = []
192
+ for i in range(len(hits[0])):
193
+ page_content = st.session_state.texts[hits[0][i]['corpus_id']]
194
+ page_contents.append(page_content)
195
+ temp_page_contents=str(page_contents)
196
+ final_page_contents = temp_page_contents.replace('\\n', '')
197
+ random_string = generate_random_string(20)
198
+ i_file_path = random_string + ".txt"
199
+ with open(i_file_path, "w", encoding="utf-8") as file:
200
+ file.write(final_page_contents)
201
+ text_loader = TextLoader(i_file_path, encoding="utf-8")
202
+ loaded_documents = text_loader.load()
203
+ temp_ai_response=chain({"input_documents": loaded_documents, "question": user_question}, return_only_outputs=False)
204
+ initial_ai_response=temp_ai_response['output_text']
205
+ cleaned_initial_ai_response = remove_context(initial_ai_response)
206
+ final_ai_response = cleaned_initial_ai_response.split('<|end|>\n<|system|>\n<|end|>\n<|user|>')[0].strip().replace('\n\n', '\n').replace('<|end|>', '').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '')
207
+ #temp_ai_response = temp_ai_response['output_text']
208
+ #final_ai_response=temp_ai_response.partition('<|end|>')[0]
209
+ #i_final_ai_response = final_ai_response.replace('\n', '')
210
+ print("AI Response:")
211
+ print(final_ai_response)
212
+ st.write("AI Response:")
213
+ st.write(final_ai_response)