hienbm commited on
Commit
67fe8d2
·
verified ·
1 Parent(s): c0d9ec7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -103
app.py CHANGED
@@ -1,103 +1,198 @@
1
- import os
2
- import streamlit as st
3
- from dotenv import load_dotenv
4
- from langchain_core.messages import AIMessage, HumanMessage
5
- from langchain_community.llms import HuggingFaceEndpoint
6
- from langchain_core.output_parsers import StrOutputParser
7
- from langchain_core.prompts import ChatPromptTemplate
8
-
9
- # Load environment variables from .env file
10
- load_dotenv()
11
-
12
- # Get the API token from environment variable
13
- api_token = os.getenv("API_TOKEN")
14
-
15
- # Define the repository ID and task
16
- repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
17
- task = "text-generation"
18
-
19
- # App config
20
- st.set_page_config(page_title="GOAHEAD.AI",page_icon= "🌍")
21
- st.title("GOAHEAD.AI ✈️")
22
-
23
- # Define the template outside the function
24
- template = """
25
- You are a travel assistant chatbot your name is GOAHEAD.AI designed to help users plan their trips and provide travel-related information. Here are some scenarios you should be able to handle:
26
-
27
- 1. Booking Flights: Assist users with booking flights to their desired destinations. Ask for departure city, destination city, travel dates, and any specific preferences (e.g., direct flights, airline preferences). Check available airlines and book the tickets accordingly.
28
-
29
- 2. Booking Hotels: Help users find and book accommodations. Inquire about city or region, check-in/check-out dates, number of guests, and accommodation preferences (e.g., budget, amenities).
30
-
31
- 3. Booking Rental Cars: Facilitate the booking of rental cars for travel convenience. Gather details such as pickup/drop-off locations, dates, car preferences (e.g., size, type), and any additional requirements.
32
-
33
- 4. Destination Information: Provide information about popular travel destinations. Offer insights on attractions, local cuisine, cultural highlights, weather conditions, and best times to visit.
34
-
35
- 5. Travel Tips: Offer practical travel tips and advice. Topics may include packing essentials, visa requirements, currency exchange, local customs, and safety tips.
36
-
37
- 6. Weather Updates: Give current weather updates for specific destinations or regions. Include temperature forecasts, precipitation chances, and any weather advisories.
38
-
39
- 7. Local Attractions: Suggest local attractions and points of interest based on the user's destination. Highlight must-see landmarks, museums, parks, and recreational activities.
40
-
41
- 8. Customer Service: Address customer service inquiries and provide assistance with travel-related issues. Handle queries about bookings, cancellations, refunds, and general support.
42
-
43
- Please ensure responses are informative, accurate, and tailored to the user's queries and preferences. Use natural language to engage users and provide a seamless experience throughout their travel planning journey.
44
-
45
- Chat history:
46
- {chat_history}
47
-
48
- User question:
49
- {user_question}
50
- """
51
-
52
- prompt = ChatPromptTemplate.from_template(template)
53
-
54
- # Function to get a response from the model
55
- def get_response(user_query, chat_history):
56
- # Initialize the Hugging Face Endpoint
57
- llm = HuggingFaceEndpoint(
58
- huggingfacehub_api_token=api_token,
59
- repo_id=repo_id,
60
- task=task
61
- )
62
-
63
- chain = prompt | llm | StrOutputParser()
64
-
65
- response = chain.invoke({
66
- "chat_history": chat_history,
67
- "user_question": user_query,
68
- })
69
-
70
- return response
71
-
72
- # Initialize session state
73
- if "chat_history" not in st.session_state:
74
- st.session_state.chat_history = [
75
- AIMessage(content="Hello, how can I help you?"),
76
- ]
77
-
78
- # Display chat history
79
- for message in st.session_state.chat_history:
80
- if isinstance(message, AIMessage):
81
- with st.chat_message("AI"):
82
- st.write(message.content)
83
- elif isinstance(message, HumanMessage):
84
- with st.chat_message("Human"):
85
- st.write(message.content)
86
-
87
- # User input
88
- user_query = st.chat_input("Type your message here...")
89
- if user_query is not None and user_query != "":
90
- st.session_state.chat_history.append(HumanMessage(content=user_query))
91
-
92
- with st.chat_message("Human"):
93
- st.markdown(user_query)
94
-
95
- response = get_response(user_query, st.session_state.chat_history)
96
-
97
- # Remove any unwanted prefixes from the response
98
- response = response.replace("AI response:", "").replace("chat response:", "").replace("bot response:", "").strip()
99
-
100
- with st.chat_message("AI"):
101
- st.write(response)
102
-
103
- st.session_state.chat_history.append(AIMessage(content=response))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Untitled68.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1h4tpXH6r9B2VZLVwksIkuuVpcrXTUnuJ
8
+ """
9
+
10
+ import torch
11
+ import bitsandbytes as bnb
12
+ import transformers
13
+ import re
14
+ import pandas as pd
15
+ import os
16
+ import streamlit as st
17
+
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
19
+ from langchain.llms import HuggingFacePipeline
20
+ from langchain.prompts import PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
21
+ from langchain_core.output_parsers import StrOutputParser
22
+ from langchain_community.document_loaders import YoutubeLoader, DataFrameLoader
23
+ from langchain_community.vectorstores.utils import filter_complex_metadata
24
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
25
+ from langchain_community.embeddings import HuggingFaceEmbeddings
26
+ from langchain_community.vectorstores import FAISS
27
+ from langchain.schema.runnable import RunnablePassthrough
28
+ from langchain_core.messages import AIMessage, HumanMessage
29
+ from dotenv import load_dotenv
30
+
31
+ # Get the API token from environment variable
32
+ api_token = os.getenv("API_TOKEN")
33
+
34
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:15000"
35
+
36
+ model_id = "google/gemma-2-9b-it"
37
+ quantization_config = BitsAndBytesConfig(load_in_4bit=True)
38
+
39
+ tokenizer = AutoTokenizer.from_pretrained(
40
+ model_id,
41
+ return_tensors="pt",
42
+ padding=True,
43
+ truncation=True,
44
+ trust_remote_code=True,
45
+ )
46
+ tokenizer.pad_token = tokenizer.eos_token
47
+ tokenizer.padding_side = "right"
48
+
49
+ model = AutoModelForCausalLM.from_pretrained(
50
+ model_id,
51
+ quantization_config=quantization_config,
52
+ device_map="auto",
53
+ low_cpu_mem_usage=True,
54
+ pad_token_id=0,
55
+ )
56
+ model.config.use_cache = False
57
+
58
+ # Create a text generation pipeline with specific settings
59
+ pipe = transformers.pipeline(
60
+ task="text-generation",
61
+ model=model,
62
+ tokenizer=tokenizer,
63
+ torch_dtype=torch.float16,
64
+ device_map="auto",
65
+ # do_sample=True,
66
+ # top_k=10,
67
+ temperature=0.0,
68
+ top_p=0.9,
69
+ num_return_sequences=1,
70
+ eos_token_id=tokenizer.eos_token_id,
71
+ max_length=4096,
72
+ truncation=True,
73
+ )
74
+
75
+ chat_model = HuggingFacePipeline(pipeline=pipe)
76
+
77
+
78
+ template = """
79
+ You are a genius trader with extensive knowledge of the financial and stock markets, capable of providing deep and insightful analysis of financial stocks with remarkable accuracy.
80
+
81
+ **ALWAYS**
82
+ Summarize and provide the main insights.
83
+ Be as detailed as possible, but don't make up any information that’s not from the context.
84
+ If you don't know an answer, say you don't know.
85
+ Let's think step by step.
86
+
87
+ Please ensure responses are informative, accurate, and tailored to the user's queries and preferences.
88
+ Use natural language to engage users and provide readable content throughout your response.
89
+
90
+ {context}
91
+ """
92
+
93
+ review_system_prompt = SystemMessagePromptTemplate(
94
+ prompt=PromptTemplate(
95
+ input_variables=["context"],
96
+ template=template,
97
+ )
98
+ )
99
+
100
+ review_human_prompt = HumanMessagePromptTemplate(
101
+ prompt=PromptTemplate(
102
+ input_variables=["question"],
103
+ template="{question}",
104
+ )
105
+ )
106
+ messages = [review_system_prompt, review_human_prompt]
107
+
108
+ review_prompt_template = ChatPromptTemplate(
109
+ input_variables=["context", "question"],
110
+ messages=messages,
111
+ )
112
+
113
+
114
+ def find_youtube_links(text):
115
+ # Define the regular expression pattern for YouTube URLs
116
+ youtube_regex = (r'(https?://(?:www\.)?(?:youtube\.com/watch\?v=|youtu\.be/)[^ \n]+)')
117
+ # Use re.findall() to find all matches in the text
118
+ matches = re.findall(youtube_regex, text)
119
+ return str(' '.join(matches))
120
+
121
+
122
+ # Function to get a response from the model
123
+ def get_response(user_query):
124
+ review_chain = (
125
+ {"context": reviews_retriever, "question": RunnablePassthrough()}
126
+ | review_prompt_template
127
+ | chat_model
128
+ | StrOutputParser()
129
+ )
130
+ response = review_chain.invoke(user_query)
131
+ return response
132
+
133
+ # App config
134
+ st.set_page_config(page_title="GOAHEAD.VN", page_icon="🌍")
135
+ st.title("Summary and provide insights from youtube news.")
136
+
137
+ # Initialize session state
138
+ if "chat_history" not in st.session_state:
139
+ st.session_state.chat_history = [
140
+ AIMessage(content="Hello, how can I help you?"),
141
+ ]
142
+
143
+ # Display chat history
144
+ for message in st.session_state.chat_history:
145
+ if isinstance(message, AIMessage):
146
+ with st.chat_message("AI"):
147
+ st.write(message.content)
148
+ elif isinstance(message, HumanMessage):
149
+ with st.chat_message("Human"):
150
+ st.write(message.content)
151
+
152
+ # User input
153
+ user_query = st.chat_input("Type your message here...")
154
+
155
+ if user_query is not None and find_youtube_links(user_query) != "":
156
+ st.session_state.chat_history.append(HumanMessage(content=user_query))
157
+
158
+ with st.chat_message("Human"):
159
+ st.markdown(user_query)
160
+
161
+ loader = YoutubeLoader.from_youtube_url(
162
+ find_youtube_links(url),
163
+ add_video_info=False,
164
+ language=["en", "vi"],
165
+ translation="en",
166
+ )
167
+ docs = loader.load()
168
+ # Convert the loaded documents to a list of dictionaries
169
+ data_list = [
170
+ {
171
+ "source": doc.metadata['source'],
172
+ "page_content": doc.page_content
173
+ }
174
+ for doc in docs
175
+ ]
176
+ df = pd.DataFrame(data_list)
177
+ loader = DataFrameLoader(df, page_content_column='page_content')
178
+ content = loader.load()
179
+ content = filter_complex_metadata(content)
180
+
181
+ # Split the document into chunks with a specified chunk size
182
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
183
+ all_splits = text_splitter.split_documents(content)
184
+
185
+ # Initialize the embedding model
186
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
187
+
188
+ # Store the document into a vector store with a specific embedding model
189
+ vectorstore = FAISS.from_documents(all_splits, embedding_model)
190
+
191
+ reviews_retriever = vectorstore.as_retriever()
192
+
193
+ response = get_response("Help me summary and provide main insights.")
194
+
195
+ with st.chat_message("AI"):
196
+ st.write(response)
197
+
198
+ st.session_state.chat_history.append(AIMessage(content=response))