Upload app.py
Browse files
app.py
CHANGED
@@ -1,87 +1,46 @@
|
|
1 |
-
import torch
|
2 |
-
import bitsandbytes as bnb
|
3 |
-
import transformers
|
4 |
-
import bs4
|
5 |
-
import pandas as pd
|
6 |
-
import re
|
7 |
-
import streamlit as st
|
8 |
-
import pandas as pd
|
9 |
import os
|
10 |
-
|
11 |
from dotenv import load_dotenv
|
12 |
from langchain_core.messages import AIMessage, HumanMessage
|
|
|
13 |
from langchain_core.output_parsers import StrOutputParser
|
14 |
-
from
|
15 |
-
|
16 |
-
from langchain_community.document_loaders import YoutubeLoader
|
17 |
-
from langchain_community.document_loaders import WebBaseLoader, DataFrameLoader, CSVLoader
|
18 |
-
from langchain_community.vectorstores.utils import filter_complex_metadata
|
19 |
-
from langchain_community.embeddings import HuggingFaceEmbeddings
|
20 |
-
from langchain_community.vectorstores import FAISS
|
21 |
-
from langchain.chains import RetrievalQA
|
22 |
-
from langchain.llms import HuggingFacePipeline
|
23 |
-
from langchain.prompts import PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
24 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
25 |
-
|
26 |
-
from huggingface_hub import login
|
27 |
# Load environment variables from .env file
|
28 |
load_dotenv()
|
29 |
|
30 |
# Get the API token from environment variable
|
31 |
api_token = os.getenv("API_TOKEN")
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
return_tensors="pt",
|
41 |
-
padding=True,
|
42 |
-
truncation=True,
|
43 |
-
trust_remote_code=True,
|
44 |
-
)
|
45 |
-
tokenizer.pad_token = tokenizer.eos_token
|
46 |
-
tokenizer.padding_side = "right"
|
47 |
-
|
48 |
-
model = AutoModelForCausalLM.from_pretrained(
|
49 |
-
model_id,
|
50 |
-
quantization_config=quantization_config,
|
51 |
-
device_map="auto",
|
52 |
-
low_cpu_mem_usage=True,
|
53 |
-
pad_token_id=0,
|
54 |
-
)
|
55 |
-
model.config.use_cache = False
|
56 |
-
|
57 |
-
# Create a text generation pipeline with specific settings
|
58 |
-
pipe = transformers.pipeline(
|
59 |
-
task="text-generation",
|
60 |
-
model=model,
|
61 |
-
tokenizer=tokenizer,
|
62 |
-
torch_dtype=torch.float16,
|
63 |
-
device_map="auto",
|
64 |
-
temperature=0.0,
|
65 |
-
top_p=0.9,
|
66 |
-
num_return_sequences=1,
|
67 |
-
eos_token_id=tokenizer.eos_token_id,
|
68 |
-
max_length=4096,
|
69 |
-
truncation=True,
|
70 |
-
)
|
71 |
-
|
72 |
-
chat_model = HuggingFacePipeline(pipeline=pipe)
|
73 |
|
|
|
74 |
template = """
|
75 |
-
You are a
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
-
|
78 |
-
Summarize and provide the main insights.
|
79 |
-
Be as detailed as possible, but don't make up any information that’s not from the context.
|
80 |
-
If you don't know an answer, say you don't know.
|
81 |
-
Let's think step by step.
|
82 |
|
83 |
-
|
84 |
-
|
|
|
85 |
|
86 |
Chat history:
|
87 |
{chat_history}
|
@@ -90,21 +49,32 @@ User question:
|
|
90 |
{user_question}
|
91 |
"""
|
92 |
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
-
|
96 |
-
# Define the regular expression pattern for YouTube URLs
|
97 |
-
youtube_regex = (r'(https?://(?:www\.)?(?:youtube\.com/watch\?v=|youtu\.be/)[^ \n]+)')
|
98 |
-
# Use re.findall() to find all matches in the text
|
99 |
-
matches = re.findall(youtube_regex, text)
|
100 |
-
return str(' '.join(matches))
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
# Initialize session state
|
104 |
if "chat_history" not in st.session_state:
|
105 |
-
st.session_state.chat_history = [
|
106 |
-
|
107 |
-
|
|
|
108 |
# Display chat history
|
109 |
for message in st.session_state.chat_history:
|
110 |
if isinstance(message, AIMessage):
|
@@ -114,7 +84,6 @@ for message in st.session_state.chat_history:
|
|
114 |
with st.chat_message("Human"):
|
115 |
st.write(message.content)
|
116 |
|
117 |
-
|
118 |
# User input
|
119 |
user_query = st.chat_input("Type your message here...")
|
120 |
if user_query is not None and user_query != "":
|
@@ -122,51 +91,13 @@ if user_query is not None and user_query != "":
|
|
122 |
|
123 |
with st.chat_message("Human"):
|
124 |
st.markdown(user_query)
|
125 |
-
|
126 |
-
loader = YoutubeLoader.from_youtube_url(
|
127 |
-
find_youtube_links(user_query),
|
128 |
-
add_video_info=False,
|
129 |
-
language=["en", "vi"],
|
130 |
-
translation="en",
|
131 |
-
)
|
132 |
-
docs = loader.load()
|
133 |
-
# Convert the loaded documents to a list of dictionaries
|
134 |
-
data_list = [
|
135 |
-
{
|
136 |
-
"source": doc.metadata['source'],
|
137 |
-
"page_content": doc.page_content
|
138 |
-
}
|
139 |
-
for doc in docs
|
140 |
-
]
|
141 |
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
# Split the document into chunks with a specified chunk size
|
148 |
-
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
|
149 |
-
all_splits = text_splitter.split_documents(content)
|
150 |
-
|
151 |
-
# Initialize the embedding model
|
152 |
-
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
|
153 |
-
|
154 |
-
# Store the document into a vector store with a specific embedding model
|
155 |
-
vectorstore = FAISS.from_documents(all_splits, embedding_model)
|
156 |
-
reviews_retriever = vectorstore.as_retriever()
|
157 |
-
|
158 |
-
# Function to get a response from the model
|
159 |
-
def get_response(user_query, chat_history):
|
160 |
-
chain = prompt_template | chat_model | StrOutputParser()
|
161 |
-
response = chain.invoke({
|
162 |
-
"user_question": user_query,
|
163 |
-
"chat_history": chat_history,
|
164 |
-
})
|
165 |
-
return response
|
166 |
-
|
167 |
-
response = get_response(reviews_retriever, st.session_state.chat_history)
|
168 |
-
|
169 |
with st.chat_message("AI"):
|
170 |
st.write(response)
|
171 |
|
172 |
-
st.session_state.chat_history.append(AIMessage(content=response))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
import streamlit as st
|
3 |
from dotenv import load_dotenv
|
4 |
from langchain_core.messages import AIMessage, HumanMessage
|
5 |
+
from langchain_community.llms import HuggingFaceEndpoint
|
6 |
from langchain_core.output_parsers import StrOutputParser
|
7 |
+
from langchain_core.prompts import ChatPromptTemplate
|
8 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
# Load environment variables from .env file
|
10 |
load_dotenv()
|
11 |
|
12 |
# Get the API token from environment variable
|
13 |
api_token = os.getenv("API_TOKEN")
|
14 |
|
15 |
+
# Define the repository ID and task
|
16 |
+
repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
17 |
+
task = "text-generation"
|
18 |
+
|
19 |
+
# App config
|
20 |
+
st.set_page_config(page_title="GOAHEAD.AI",page_icon= "🌍")
|
21 |
+
st.title("GOAHEAD.AI ✈️")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
# Define the template outside the function
|
24 |
template = """
|
25 |
+
You are a travel assistant chatbot your name is GOAHEAD.AI designed to help users plan their trips and provide travel-related information. Here are some scenarios you should be able to handle:
|
26 |
+
|
27 |
+
1. Booking Flights: Assist users with booking flights to their desired destinations. Ask for departure city, destination city, travel dates, and any specific preferences (e.g., direct flights, airline preferences). Check available airlines and book the tickets accordingly.
|
28 |
+
|
29 |
+
2. Booking Hotels: Help users find and book accommodations. Inquire about city or region, check-in/check-out dates, number of guests, and accommodation preferences (e.g., budget, amenities).
|
30 |
+
|
31 |
+
3. Booking Rental Cars: Facilitate the booking of rental cars for travel convenience. Gather details such as pickup/drop-off locations, dates, car preferences (e.g., size, type), and any additional requirements.
|
32 |
+
|
33 |
+
4. Destination Information: Provide information about popular travel destinations. Offer insights on attractions, local cuisine, cultural highlights, weather conditions, and best times to visit.
|
34 |
+
|
35 |
+
5. Travel Tips: Offer practical travel tips and advice. Topics may include packing essentials, visa requirements, currency exchange, local customs, and safety tips.
|
36 |
+
|
37 |
+
6. Weather Updates: Give current weather updates for specific destinations or regions. Include temperature forecasts, precipitation chances, and any weather advisories.
|
38 |
|
39 |
+
7. Local Attractions: Suggest local attractions and points of interest based on the user's destination. Highlight must-see landmarks, museums, parks, and recreational activities.
|
|
|
|
|
|
|
|
|
40 |
|
41 |
+
8. Customer Service: Address customer service inquiries and provide assistance with travel-related issues. Handle queries about bookings, cancellations, refunds, and general support.
|
42 |
+
|
43 |
+
Please ensure responses are informative, accurate, and tailored to the user's queries and preferences. Use natural language to engage users and provide a seamless experience throughout their travel planning journey.
|
44 |
|
45 |
Chat history:
|
46 |
{chat_history}
|
|
|
49 |
{user_question}
|
50 |
"""
|
51 |
|
52 |
+
prompt = ChatPromptTemplate.from_template(template)
|
53 |
+
|
54 |
+
# Function to get a response from the model
|
55 |
+
def get_response(user_query, chat_history):
|
56 |
+
# Initialize the Hugging Face Endpoint
|
57 |
+
llm = HuggingFaceEndpoint(
|
58 |
+
huggingfacehub_api_token=api_token,
|
59 |
+
repo_id=repo_id,
|
60 |
+
task=task
|
61 |
+
)
|
62 |
|
63 |
+
chain = prompt | llm | StrOutputParser()
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
response = chain.invoke({
|
66 |
+
"chat_history": chat_history,
|
67 |
+
"user_question": user_query,
|
68 |
+
})
|
69 |
+
|
70 |
+
return response
|
71 |
|
72 |
# Initialize session state
|
73 |
if "chat_history" not in st.session_state:
|
74 |
+
st.session_state.chat_history = [
|
75 |
+
AIMessage(content="Hello, how can I help you?"),
|
76 |
+
]
|
77 |
+
|
78 |
# Display chat history
|
79 |
for message in st.session_state.chat_history:
|
80 |
if isinstance(message, AIMessage):
|
|
|
84 |
with st.chat_message("Human"):
|
85 |
st.write(message.content)
|
86 |
|
|
|
87 |
# User input
|
88 |
user_query = st.chat_input("Type your message here...")
|
89 |
if user_query is not None and user_query != "":
|
|
|
91 |
|
92 |
with st.chat_message("Human"):
|
93 |
st.markdown(user_query)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
+
response = get_response(user_query, st.session_state.chat_history)
|
96 |
+
|
97 |
+
# Remove any unwanted prefixes from the response
|
98 |
+
response = response.replace("AI response:", "").replace("chat response:", "").replace("bot response:", "").strip()
|
99 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
with st.chat_message("AI"):
|
101 |
st.write(response)
|
102 |
|
103 |
+
st.session_state.chat_history.append(AIMessage(content=response))
|