hienbm committed on
Commit
c0d9ec7
·
verified ·
1 Parent(s): 5dc4a5d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -125
app.py CHANGED
@@ -1,87 +1,46 @@
1
- import torch
2
- import bitsandbytes as bnb
3
- import transformers
4
- import bs4
5
- import pandas as pd
6
- import re
7
- import streamlit as st
8
- import pandas as pd
9
  import os
10
-
11
  from dotenv import load_dotenv
12
  from langchain_core.messages import AIMessage, HumanMessage
 
13
  from langchain_core.output_parsers import StrOutputParser
14
- from langchain.schema.runnable import RunnablePassthrough
15
- from langchain.text_splitter import RecursiveCharacterTextSplitter
16
- from langchain_community.document_loaders import YoutubeLoader
17
- from langchain_community.document_loaders import WebBaseLoader, DataFrameLoader, CSVLoader
18
- from langchain_community.vectorstores.utils import filter_complex_metadata
19
- from langchain_community.embeddings import HuggingFaceEmbeddings
20
- from langchain_community.vectorstores import FAISS
21
- from langchain.chains import RetrievalQA
22
- from langchain.llms import HuggingFacePipeline
23
- from langchain.prompts import PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
24
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
25
-
26
- from huggingface_hub import login
27
  # Load environment variables from .env file
28
  load_dotenv()
29
 
30
  # Get the API token from environment variable
31
  api_token = os.getenv("API_TOKEN")
32
 
33
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:15000"
34
-
35
- model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
36
- quantization_config = BitsAndBytesConfig(load_in_4bit=True)
37
-
38
- tokenizer = AutoTokenizer.from_pretrained(
39
- model_id,
40
- return_tensors="pt",
41
- padding=True,
42
- truncation=True,
43
- trust_remote_code=True,
44
- )
45
- tokenizer.pad_token = tokenizer.eos_token
46
- tokenizer.padding_side = "right"
47
-
48
- model = AutoModelForCausalLM.from_pretrained(
49
- model_id,
50
- quantization_config=quantization_config,
51
- device_map="auto",
52
- low_cpu_mem_usage=True,
53
- pad_token_id=0,
54
- )
55
- model.config.use_cache = False
56
-
57
- # Create a text generation pipeline with specific settings
58
- pipe = transformers.pipeline(
59
- task="text-generation",
60
- model=model,
61
- tokenizer=tokenizer,
62
- torch_dtype=torch.float16,
63
- device_map="auto",
64
- temperature=0.0,
65
- top_p=0.9,
66
- num_return_sequences=1,
67
- eos_token_id=tokenizer.eos_token_id,
68
- max_length=4096,
69
- truncation=True,
70
- )
71
-
72
- chat_model = HuggingFacePipeline(pipeline=pipe)
73
 
 
74
  template = """
75
- You are a genius trader with extensive knowledge of the financial and stock markets, capable of providing deep and insightful analysis of financial stocks with remarkable accuracy.
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- **ALWAYS**
78
- Summarize and provide the main insights.
79
- Be as detailed as possible, but don't make up any information that’s not from the context.
80
- If you don't know an answer, say you don't know.
81
- Let's think step by step.
82
 
83
- Please ensure responses are informative, accurate, and tailored to the user's queries and preferences.
84
- Use natural language to engage users and provide readable content throughout your response.
 
85
 
86
  Chat history:
87
  {chat_history}
@@ -90,21 +49,32 @@ User question:
90
  {user_question}
91
  """
92
 
93
- prompt_template = ChatPromptTemplate.from_template(template)
 
 
 
 
 
 
 
 
 
94
 
95
- def find_youtube_links(text):
96
- # Define the regular expression pattern for YouTube URLs
97
- youtube_regex = (r'(https?://(?:www\.)?(?:youtube\.com/watch\?v=|youtu\.be/)[^ \n]+)')
98
- # Use re.findall() to find all matches in the text
99
- matches = re.findall(youtube_regex, text)
100
- return str(' '.join(matches))
101
 
 
 
 
 
 
 
102
 
103
  # Initialize session state
104
  if "chat_history" not in st.session_state:
105
- st.session_state.chat_history = [AIMessage(content="Hello, how can I help you?")]
106
-
107
-
 
108
  # Display chat history
109
  for message in st.session_state.chat_history:
110
  if isinstance(message, AIMessage):
@@ -114,7 +84,6 @@ for message in st.session_state.chat_history:
114
  with st.chat_message("Human"):
115
  st.write(message.content)
116
 
117
-
118
  # User input
119
  user_query = st.chat_input("Type your message here...")
120
  if user_query is not None and user_query != "":
@@ -122,51 +91,13 @@ if user_query is not None and user_query != "":
122
 
123
  with st.chat_message("Human"):
124
  st.markdown(user_query)
125
-
126
- loader = YoutubeLoader.from_youtube_url(
127
- find_youtube_links(user_query),
128
- add_video_info=False,
129
- language=["en", "vi"],
130
- translation="en",
131
- )
132
- docs = loader.load()
133
- # Convert the loaded documents to a list of dictionaries
134
- data_list = [
135
- {
136
- "source": doc.metadata['source'],
137
- "page_content": doc.page_content
138
- }
139
- for doc in docs
140
- ]
141
 
142
- df = pd.DataFrame(data_list)
143
- loader = DataFrameLoader(df, page_content_column='page_content')
144
- content = loader.load()
145
- # reviews = filter_complex_metadata(reviews)
146
-
147
- # Split the document into chunks with a specified chunk size
148
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
149
- all_splits = text_splitter.split_documents(content)
150
-
151
- # Initialize the embedding model
152
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
153
-
154
- # Store the document into a vector store with a specific embedding model
155
- vectorstore = FAISS.from_documents(all_splits, embedding_model)
156
- reviews_retriever = vectorstore.as_retriever()
157
-
158
- # Function to get a response from the model
159
- def get_response(user_query, chat_history):
160
- chain = prompt_template | chat_model | StrOutputParser()
161
- response = chain.invoke({
162
- "user_question": user_query,
163
- "chat_history": chat_history,
164
- })
165
- return response
166
-
167
- response = get_response(reviews_retriever, st.session_state.chat_history)
168
-
169
  with st.chat_message("AI"):
170
  st.write(response)
171
 
172
- st.session_state.chat_history.append(AIMessage(content=response))
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import streamlit as st
3
  from dotenv import load_dotenv
4
  from langchain_core.messages import AIMessage, HumanMessage
5
+ from langchain_community.llms import HuggingFaceEndpoint
6
  from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_core.prompts import ChatPromptTemplate
8
+
 
 
 
 
 
 
 
 
 
 
 
9
  # Load environment variables from .env file
10
  load_dotenv()
11
 
12
  # Get the API token from environment variable
13
  api_token = os.getenv("API_TOKEN")
14
 
15
+ # Define the repository ID and task
16
+ repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
17
+ task = "text-generation"
18
+
19
+ # App config
20
+ st.set_page_config(page_title="GOAHEAD.AI",page_icon= "🌍")
21
+ st.title("GOAHEAD.AI ✈️")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # Define the template outside the function
24
  template = """
25
+ You are a travel assistant chatbot your name is GOAHEAD.AI designed to help users plan their trips and provide travel-related information. Here are some scenarios you should be able to handle:
26
+
27
+ 1. Booking Flights: Assist users with booking flights to their desired destinations. Ask for departure city, destination city, travel dates, and any specific preferences (e.g., direct flights, airline preferences). Check available airlines and book the tickets accordingly.
28
+
29
+ 2. Booking Hotels: Help users find and book accommodations. Inquire about city or region, check-in/check-out dates, number of guests, and accommodation preferences (e.g., budget, amenities).
30
+
31
+ 3. Booking Rental Cars: Facilitate the booking of rental cars for travel convenience. Gather details such as pickup/drop-off locations, dates, car preferences (e.g., size, type), and any additional requirements.
32
+
33
+ 4. Destination Information: Provide information about popular travel destinations. Offer insights on attractions, local cuisine, cultural highlights, weather conditions, and best times to visit.
34
+
35
+ 5. Travel Tips: Offer practical travel tips and advice. Topics may include packing essentials, visa requirements, currency exchange, local customs, and safety tips.
36
+
37
+ 6. Weather Updates: Give current weather updates for specific destinations or regions. Include temperature forecasts, precipitation chances, and any weather advisories.
38
 
39
+ 7. Local Attractions: Suggest local attractions and points of interest based on the user's destination. Highlight must-see landmarks, museums, parks, and recreational activities.
 
 
 
 
40
 
41
+ 8. Customer Service: Address customer service inquiries and provide assistance with travel-related issues. Handle queries about bookings, cancellations, refunds, and general support.
42
+
43
+ Please ensure responses are informative, accurate, and tailored to the user's queries and preferences. Use natural language to engage users and provide a seamless experience throughout their travel planning journey.
44
 
45
  Chat history:
46
  {chat_history}
 
49
  {user_question}
50
  """
51
 
52
+ prompt = ChatPromptTemplate.from_template(template)
53
+
54
+ # Function to get a response from the model
55
+ def get_response(user_query, chat_history):
56
+ # Initialize the Hugging Face Endpoint
57
+ llm = HuggingFaceEndpoint(
58
+ huggingfacehub_api_token=api_token,
59
+ repo_id=repo_id,
60
+ task=task
61
+ )
62
 
63
+ chain = prompt | llm | StrOutputParser()
 
 
 
 
 
64
 
65
+ response = chain.invoke({
66
+ "chat_history": chat_history,
67
+ "user_question": user_query,
68
+ })
69
+
70
+ return response
71
 
72
  # Initialize session state
73
  if "chat_history" not in st.session_state:
74
+ st.session_state.chat_history = [
75
+ AIMessage(content="Hello, how can I help you?"),
76
+ ]
77
+
78
  # Display chat history
79
  for message in st.session_state.chat_history:
80
  if isinstance(message, AIMessage):
 
84
  with st.chat_message("Human"):
85
  st.write(message.content)
86
 
 
87
  # User input
88
  user_query = st.chat_input("Type your message here...")
89
  if user_query is not None and user_query != "":
 
91
 
92
  with st.chat_message("Human"):
93
  st.markdown(user_query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ response = get_response(user_query, st.session_state.chat_history)
96
+
97
+ # Remove any unwanted prefixes from the response
98
+ response = response.replace("AI response:", "").replace("chat response:", "").replace("bot response:", "").strip()
99
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  with st.chat_message("AI"):
101
  st.write(response)
102
 
103
+ st.session_state.chat_history.append(AIMessage(content=response))