hienbm committed
Commit bd0fa4a · verified · 1 Parent(s): 554c667

Upload app.py

Files changed (1)
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
import os
import re

import pandas as pd
import streamlit as st
import torch
import transformers

from dotenv import load_dotenv
from huggingface_hub import login
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DataFrameLoader, YoutubeLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# bitsandbytes must be installed for the 4-bit loading below, but transformers
# uses it internally, so it never needs to be imported by name here.

# Load environment variables from .env file
load_dotenv()

# Get the API token from the environment and log in to the Hugging Face Hub
# (google/gemma-2-9b-it is a gated model, so downloading it requires a token)
api_token = os.getenv("API_TOKEN")
if api_token:
    login(token=api_token)

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:15000"

model_id = "google/gemma-2-9b-it"
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
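# NOTE (editor's sketch, not part of the original commit): load_in_4bit=True
# quantizes the model's linear layers to 4 bits at load time, cutting weight
# memory roughly 4x versus fp16. If your setup supports it, NF4 with double
# quantization is a commonly used stronger variant:
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.float16,
# )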

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
)
# return_tensors, padding, and truncation are per-call options, so they take
# effect when the tokenizer is invoked rather than in from_pretrained
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    pad_token_id=0,
)
model.config.use_cache = False
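# NOTE (editor's note): use_cache=False disables the KV cache, which mainly
# matters when training with gradient checkpointing; for an inference-only app
# like this one, leaving the cache at its default of True is usually faster.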

# Create a text generation pipeline with specific settings
pipe = transformers.pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
    do_sample=False,  # greedy decoding, the deterministic equivalent of temperature 0
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    max_length=4096,
    truncation=True,
)

chat_model = HuggingFacePipeline(pipeline=pipe)

template = """
You are a genius trader with extensive knowledge of the financial and stock markets, capable of providing deep and insightful analysis of financial stocks with remarkable accuracy.

**ALWAYS**
Summarize the context and provide the main insights.
Be as detailed as possible, but don't make up any information that's not from the context.
If you don't know an answer, say you don't know.
Let's think step by step.

Please ensure responses are informative, accurate, and tailored to the user's queries and preferences.
Use natural language to engage users and provide readable content throughout your response.

Context from the video transcript:
{context}

Chat history:
{chat_history}

User question:
{user_question}
"""

prompt_template = ChatPromptTemplate.from_template(template)
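# NOTE (editor's note): this builds a single plain-text prompt rather than
# using Gemma's native chat format; a raw text-generation pipeline accepts
# that, but tokenizer.apply_chat_template is the alternative if you want the
# model's <start_of_turn> turn markers.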

def find_youtube_links(text):
    # Regular expression pattern for YouTube URLs
    youtube_regex = r'(https?://(?:www\.)?(?:youtube\.com/watch\?v=|youtu\.be/)[^ \n]+)'
    # re.findall() returns every match found in the text
    matches = re.findall(youtube_regex, text)
    return ' '.join(matches)
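
# Quick illustration of the helper (the URL is a made-up example):
# >>> find_youtube_links("watch https://youtu.be/abc123 and tell me the gist")
# 'https://youtu.be/abc123'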


# Initialize session state
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [AIMessage(content="Hello, how can I help you?")]


# Display chat history
for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        with st.chat_message("AI"):
            st.write(message.content)
    elif isinstance(message, HumanMessage):
        with st.chat_message("Human"):
            st.write(message.content)


# User input
user_query = st.chat_input("Type your message here...")
if user_query is not None and user_query != "":
    st.session_state.chat_history.append(HumanMessage(content=user_query))

    with st.chat_message("Human"):
        st.markdown(user_query)

    # Guard against messages with no YouTube link, which would otherwise crash
    # the loader on an empty URL
    video_url = find_youtube_links(user_query)
    if not video_url:
        with st.chat_message("AI"):
            st.write("Please include a YouTube link in your message.")
        st.stop()

    loader = YoutubeLoader.from_youtube_url(
        video_url,
        add_video_info=False,
        language=["en", "vi"],
        translation="en",
    )
    docs = loader.load()
    # Convert the loaded documents to a list of dictionaries
    data_list = [
        {
            "source": doc.metadata["source"],
            "page_content": doc.page_content,
        }
        for doc in docs
    ]

    df = pd.DataFrame(data_list)
    loader = DataFrameLoader(df, page_content_column="page_content")
    content = loader.load()

    # Split the documents into chunks with a specified chunk size
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
    all_splits = text_splitter.split_documents(content)
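    # NOTE (editor's note): chunk_size and chunk_overlap are measured in
    # characters here; 1500/150 keeps each chunk small relative to the 4096
    # max_length above, and the ~10% overlap preserves sentences that would
    # otherwise be cut at chunk boundaries.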

    # Initialize the embedding model
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")

    # Store the chunks in a vector store built on that embedding model
    vectorstore = FAISS.from_documents(all_splits, embedding_model)
    reviews_retriever = vectorstore.as_retriever()
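    # NOTE (editor's sketch, not part of the original commit): as_retriever()
    # defaults to plain similarity search with k=4; if responses pull in too
    # much loosely related transcript, an explicit configuration such as
    # vectorstore.as_retriever(search_kwargs={"k": 3}) is one way to tighten it.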

    # Function to get a response from the model, grounded in the retrieved
    # transcript chunks
    def get_response(user_query, chat_history, retriever):
        docs = retriever.invoke(user_query)
        context = "\n\n".join(doc.page_content for doc in docs)
        chain = prompt_template | chat_model | StrOutputParser()
        response = chain.invoke({
            "context": context,
            "user_question": user_query,
            "chat_history": chat_history,
        })
        return response

    response = get_response(user_query, st.session_state.chat_history, reviews_retriever)

    with st.chat_message("AI"):
        st.write(response)

    st.session_state.chat_history.append(AIMessage(content=response))
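
# To run the app locally (assuming the dependencies above are installed and
# API_TOKEN is set in .env):
#   streamlit run app.py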