import os
import re

import pandas as pd
import streamlit as st
import torch
import transformers
from dotenv import load_dotenv
from huggingface_hub import login
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DataFrameLoader, YoutubeLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
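
# Assumed runtime dependencies: streamlit, torch, transformers, bitsandbytes
# (required for 4-bit loading even though it is not imported directly),
# langchain, langchain-community, faiss-cpu or faiss-gpu, sentence-transformers,
# youtube-transcript-api (used by YoutubeLoader), and python-dotenv.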

load_dotenv()

# Authenticate with the Hugging Face Hub; google/gemma-2-9b-it is a gated model.
api_token = os.getenv("API_TOKEN")
if api_token:
    login(token=api_token)

# Reduce CUDA memory fragmentation when loading a large model.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:15000"

model_id = "google/gemma-2-9b-it"

# Load the weights in 4-bit so the 9B model fits on a single GPU.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

# return_tensors, padding, and truncation are per-call tokenizer arguments,
# not from_pretrained options, so they are not passed here.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    pad_token_id=tokenizer.pad_token_id,  # keep the pad id in sync with the tokenizer
)
model.config.use_cache = False  # trades generation speed for a smaller memory footprint

pipe = transformers.pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
    do_sample=False,  # greedy decoding; transformers rejects temperature=0.0
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    max_length=4096,
    truncation=True,
)

chat_model = HuggingFacePipeline(pipeline=pipe)
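
# Optional sanity check (not part of the app flow): uncommenting this should
# print a completion if the model loaded correctly.
# print(pipe("What is a stock split?")[0]["generated_text"])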

template = """
You are a genius trader with extensive knowledge of the financial and stock markets, capable of providing deep and insightful analysis of financial stocks with remarkable accuracy.

**ALWAYS**
Summarize and provide the main insights.
Be as detailed as possible, but don't make up any information that is not in the context.
If you don't know an answer, say you don't know.
Let's think step by step.

Please ensure responses are informative, accurate, and tailored to the user's queries and preferences.
Use natural language to engage users and provide readable content throughout your response.

Context:
{context}

Chat history:
{chat_history}

User question:
{user_question}
"""

prompt_template = ChatPromptTemplate.from_template(template)
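
# Note: Gemma instruction models are trained on a specific chat format, so
# wrapping the prompt with tokenizer.apply_chat_template may improve answer
# quality; a plain text prompt still runs through the pipeline as-is.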


def find_youtube_links(text):
    """Return the first YouTube URL found in text, or an empty string."""
    youtube_regex = r'(https?://(?:www\.)?(?:youtube\.com/watch\?v=|youtu\.be/)[^ \n]+)'
    matches = re.findall(youtube_regex, text)
    # YoutubeLoader.from_youtube_url expects a single URL, so return only the
    # first match.
    return matches[0] if matches else ""
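
# Example (hypothetical URL): find_youtube_links("summarize https://youtu.be/abc123 for me")
# returns "https://youtu.be/abc123".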


if "chat_history" not in st.session_state:
    st.session_state.chat_history = [AIMessage(content="Hello, how can I help you?")]

for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        with st.chat_message("AI"):
            st.write(message.content)
    elif isinstance(message, HumanMessage):
        with st.chat_message("Human"):
            st.write(message.content)

user_query = st.chat_input("Type your message here...")
if user_query:
    st.session_state.chat_history.append(HumanMessage(content=user_query))

    with st.chat_message("Human"):
        st.markdown(user_query)

    # Load the transcript of the YouTube video referenced in the message.
    video_url = find_youtube_links(user_query)
    if not video_url:
        st.warning("Please include a YouTube link in your message.")
        st.stop()

    loader = YoutubeLoader.from_youtube_url(
        video_url,
        add_video_info=False,
        language=["en", "vi"],
        translation="en",
    )
    docs = loader.load()

    # Keep only the source URL and the transcript text for indexing.
    data_list = [
        {"source": doc.metadata["source"], "page_content": doc.page_content}
        for doc in docs
    ]

    df = pd.DataFrame(data_list)
    loader = DataFrameLoader(df, page_content_column="page_content")
    content = loader.load()

    # Split the transcript into overlapping chunks and index them for retrieval.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
    all_splits = text_splitter.split_documents(content)

    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")

    vectorstore = FAISS.from_documents(all_splits, embedding_model)
    reviews_retriever = vectorstore.as_retriever()
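
    # Assumption: as_retriever() defaults to similarity search over the top 4
    # chunks; pass search_kwargs={"k": ...} to widen or narrow the context.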

    def get_response(user_query, chat_history, retriever):
        # Fetch transcript chunks relevant to the question and pass them to
        # the prompt as context alongside the running chat history.
        docs = retriever.invoke(user_query)
        context = "\n\n".join(doc.page_content for doc in docs)
        chain = prompt_template | chat_model | StrOutputParser()
        return chain.invoke({
            "context": context,
            "user_question": user_query,
            "chat_history": chat_history,
        })

    response = get_response(user_query, st.session_state.chat_history, reviews_retriever)

    with st.chat_message("AI"):
        st.write(response)

    st.session_state.chat_history.append(AIMessage(content=response))