binqiangliu commited on
Commit
9d613db
·
1 Parent(s): f1d59ac

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -0
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import streamlit as st
3
+ from langchain.chains import RetrievalQA
4
+ from langchain.document_loaders import WebBaseLoader
5
+ from langchain.chains.question_answering import load_qa_chain
6
+ from langchain import PromptTemplate, LLMChain
7
+ from langchain import HuggingFaceHub
8
+ from PyPDF2 import PdfReader
9
+ from langchain.text_splitter import CharacterTextSplitter
10
+ from langchain.document_loaders import TextLoader
11
+ from sentence_transformers.util import semantic_search
12
+ import requests
13
+ from pathlib import Path
14
+ from time import sleep
15
+ import torch
16
+ import os
17
+ import random
18
+ import string
19
+ from dotenv import load_dotenv
20
+ load_dotenv()
21
+
22
+ st.set_page_config(page_title="USinoIP Website AI Chat Assistant", layout="wide")
23
+ st.subheader("Welcome to USinoIP Website AI Chat Assistant.")
24
+
25
+ css_file = "main.css"
26
+ with open(css_file) as f:
27
+ st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)
28
+
29
+ HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
30
+ model_id = os.getenv('model_id')
31
+ hf_token = os.getenv('hf_token')
32
+ repo_id = os.getenv('repo_id')
33
+ #HUGGINGFACEHUB_API_TOKEN = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
34
+ #model_id = os.environ.get('model_id')
35
+ #hf_token = os.environ.get('hf_token')
36
+ #repo_id = os.environ.get('repo_id')
37
+
38
+ api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
39
+ headers = {"Authorization": f"Bearer {hf_token}"}
40
+
41
+ def get_embeddings(input_str_texts):
42
+ response = requests.post(api_url, headers=headers, json={"inputs": input_str_texts, "options":{"wait_for_model":True}})
43
+ return response.json()
44
+
45
+ llm = HuggingFaceHub(repo_id=repo_id,
46
+ model_kwargs={"min_length":100,
47
+ "max_new_tokens":1024, "do_sample":True,
48
+ "temperature":0.1,
49
+ "top_k":50,
50
+ "top_p":0.95, "eos_token_id":49155})
51
+
52
+ prompt_template = """
53
+ #You are a very helpful AI assistant. Please ONLY use {context} to answer the user's input question. If you don't know the answer, just say that you don't know. DON'T try to make up an answer and do NOT go beyond the given context without the user's explicitly asking you to do so!
54
+ You are a very helpful AI assistant. Please response to the user's input question with as many details as possible.
55
+ Question: {question}
56
+ Helpful AI Repsonse:
57
+ """
58
+ PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
59
+ chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT)
60
+
61
+ def generate_random_string(length):
62
+ letters = string.ascii_lowercase
63
+ return ''.join(random.choice(letters) for i in range(length))
64
+
65
+ #url="https://www.usinoip.com"
66
+ url="https://www.usinoip.com/UpdatesAbroad/290.html"
67
+ texts=""
68
+ raw_text=""
69
+ user_question = ""
70
+ initial_embeddings=""
71
+ db_embeddings = ""
72
+ i_file_path=""
73
+ file_path = ""
74
+ random_string=""
75
+ wechat_image= "WeChatCode.jpg"
76
+
77
+ st.sidebar.markdown(
78
+ """
79
+ <style>
80
+ .blue-underline {
81
+ text-decoration: bold;
82
+ color: blue;
83
+ }
84
+ </style>
85
+ """,
86
+ unsafe_allow_html=True
87
+ )
88
+
89
+ st.markdown(
90
+ """
91
+ <style>
92
+ [data-testid=stSidebar] [data-testid=stImage]{
93
+ text-align: center;
94
+ display: block;
95
+ margin-left: auto;
96
+ margin-right: auto;
97
+ width: 50%;
98
+ }
99
+ </style>
100
+ """, unsafe_allow_html=True
101
+ )
102
+
103
+ user_question = st.text_input("Enter your query here and AI-Chat with your website:")
104
+
105
+ text_splitter = CharacterTextSplitter(
106
+ separator = "\n",
107
+ chunk_size = 1000,
108
+ chunk_overlap = 200,
109
+ length_function = len,
110
+ )
111
+
112
+ with st.sidebar:
113
+ st.subheader("You are chatting with USinoIP official website.")
114
+ st.write("Note & Disclaimer: This app is provided on open source framework and is for information purpose only. NO guarantee is offered regarding information accuracy. NO liability could be claimed against whoever associated with this app in any manner. User should consult a qualified legal professional for legal advice.")
115
+ st.sidebar.markdown("Contact: [[email protected]](mailto:[email protected])")
116
+ st.sidebar.markdown('WeChat: <span class="blue-underline">pat2win</span>, or scan the code below.', unsafe_allow_html=True)
117
+ st.image(wechat_image)
118
+ st.subheader("Enjoy Chatting!")
119
+ st.sidebar.markdown('<span class="blue-underline">Life Enhancing with AI.</span>', unsafe_allow_html=True)
120
+ with st.spinner("Preparing website materials for you..."):
121
+ try:
122
+ loader = WebBaseLoader(url)
123
+ raw_text = loader.load()
124
+ page_content = raw_text[0].page_content
125
+ page_content = str(page_content)
126
+ temp_texts = text_splitter.split_text(page_content)
127
+ texts = temp_texts
128
+ initial_embeddings=get_embeddings(texts)
129
+ db_embeddings = torch.FloatTensor(initial_embeddings)
130
+ except Exception as e:
131
+ st.write("Unknow error.")
132
+ print("Please enter a valide URL.")
133
+ st.stop()
134
+
135
+ if user_question.strip().isspace() or user_question.isspace():
136
+ st.write("Query Empty. Please enter a valid query first.")
137
+ st.stop()
138
+ elif user_question == "exit":
139
+ st.stop()
140
+ elif user_question == "":
141
+ print("Query Empty. Please enter a valid query first.")
142
+ st.stop()
143
+ elif user_question != "":
144
+ #st.write("Your query: "+user_question)
145
+ print("Your query: "+user_question)
146
+
147
+ with st.spinner("AI Thinking...Please wait a while to Cheers!"):
148
+ q_embedding=get_embeddings(user_question)
149
+ final_q_embedding = torch.FloatTensor(q_embedding)
150
+ hits = semantic_search(final_q_embedding, db_embeddings, top_k=5)
151
+ page_contents = []
152
+ for i in range(len(hits[0])):
153
+ page_content = texts[hits[0][i]['corpus_id']]
154
+ page_contents.append(page_content)
155
+ temp_page_contents=str(page_contents)
156
+ final_page_contents = temp_page_contents.replace('\\n', '')
157
+ random_string = generate_random_string(20)
158
+ i_file_path = random_string + ".txt"
159
+ with open(i_file_path, "w", encoding="utf-8") as file:
160
+ file.write(final_page_contents)
161
+ loader = TextLoader(i_file_path, encoding="utf-8")
162
+ loaded_documents = loader.load()
163
+ temp_ai_response=chain({"input_documents": loaded_documents, "question": user_question}, return_only_outputs=False)
164
+ temp_ai_response = temp_ai_response['output_text']
165
+ final_ai_response=temp_ai_response.partition('<|end|>')[0]
166
+ i_final_ai_response = final_ai_response.replace('\n', '')
167
+ st.write("AI Response:")
168
+ st.write(i_final_ai_response)