Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings | |
from langchain.vectorstores.faiss import FAISS | |
from langchain.chains import ChatVectorDBChain | |
from huggingface_hub import snapshot_download | |
from langchain.chat_models import ChatOpenAI | |
from langchain.prompts.chat import ( | |
ChatPromptTemplate, | |
SystemMessagePromptTemplate, | |
AIMessagePromptTemplate, | |
HumanMessagePromptTemplate, | |
) | |
from langchain.schema import ( | |
AIMessage, | |
HumanMessage, | |
SystemMessage | |
) | |
# Browser-tab title and icon for the app.
# NOTE(review): the original icon text was mis-encoded ("π"); "📚" is a
# stand-in emoji — confirm the intended icon.
st.set_page_config(page_title="CFA Level 1", page_icon="📚")

#### sidebar section 1 ####
with st.sidebar:
    # Despite its name, `book` holds the chosen embedding model name; it is
    # what gets passed to load_embedding_models() below.
    book = st.radio("Choose an Embedding Model: ",
                    ["Instruct", "Sbert"]
                    )
# load embedding models
@st.cache_resource
def load_embedding_models(model):
    """Instantiate the HuggingFace embedding model named by ``model``.

    Streamlit re-executes the whole script on every widget interaction, so the
    model object is cached with ``st.cache_resource`` (keyed on ``model``)
    rather than being re-created on each rerun.

    Args:
        model: ``'Sbert'`` for sentence-transformers/all-mpnet-base-v2, or
            ``'Instruct'`` for hkunlp/instructor-large with finance-specific
            embedding and query instructions.

    Returns:
        A ``HuggingFaceEmbeddings`` (Sbert) or ``HuggingFaceInstructEmbeddings``
        (Instruct) instance.

    Raises:
        ValueError: if ``model`` is neither ``'Sbert'`` nor ``'Instruct'``.
    """
    if model == 'Sbert':
        model_sbert = "sentence-transformers/all-mpnet-base-v2"
        return HuggingFaceEmbeddings(model_name=model_sbert)
    if model == 'Instruct':
        # Documents and queries are embedded with different instructions.
        embed_instruction = "Represent the financial paragraph for document retrieval: "
        query_instruction = "Represent the question for retrieving supporting documents: "
        model_instr = "hkunlp/instructor-large"
        return HuggingFaceInstructEmbeddings(model_name=model_instr,
                                             embed_instruction=embed_instruction,
                                             query_instruction=query_instruction)
    # Fail with a clear message instead of an UnboundLocalError on `emb`.
    raise ValueError(f"Unknown embedding model {model!r}; expected 'Sbert' or 'Instruct'.")
st.title("Talk to CFA Level 1 Book")
# NOTE(review): the trailing icon was mis-encoded ("π"); "📚" is a stand-in
# emoji — confirm the intended one.
st.markdown("#### Have a conversation with the CFA Curriculum by the CFA Institute 📚")
# Embedding model selected in the sidebar radio (`book`).
embeddings = load_embedding_models(book)
##### functions ####
def load_vectorstore(_embeddings):
    """Download the prebuilt CFA Level 1 FAISS index and load it.

    Args:
        _embeddings: Embedding object handed to ``FAISS.load_local``.
            Presumably this must be the same model the index was built with
            (TODO confirm).

    Returns:
        The FAISS vector store stored in the ``CFA_Level_1`` folder of the
        ``nickmuchi/CFA_Level_1_Text_Embeddings`` dataset (``main`` revision).

    Raises:
        FileNotFoundError: if the downloaded snapshot has no ``CFA_Level_1``
            folder.
    """
    cache_dir = "cfa_level_1_cache"
    target_dir = "CFA_Level_1"
    # snapshot_download returns the local folder of the exact snapshot it
    # resolved. Building the path from it avoids walking cache_dir, which can
    # hold several snapshots once the dataset is updated and could yield a
    # stale one.
    snapshot_path = snapshot_download(repo_id="nickmuchi/CFA_Level_1_Text_Embeddings",
                                      repo_type="dataset",
                                      revision="main",
                                      allow_patterns="CFA_Level_1/*",
                                      cache_dir=cache_dir,
                                      )
    target_path = os.path.join(snapshot_path, target_dir)
    if not os.path.isdir(target_path):
        raise FileNotFoundError(
            f"FAISS index folder {target_dir!r} not found in snapshot {snapshot_path!r}"
        )
    # load faiss
    return FAISS.load_local(folder_path=target_path, embeddings=_embeddings)
def load_prompt():
    """Build the chat prompt used to answer from retrieved passages.

    The system message holds the answering instructions plus the retrieved
    ``{context}``; the human message is the user's ``{question}``.

    Returns:
        A ``ChatPromptTemplate`` made of one system message followed by one
        human message.
    """
    # `{{` / `}}` are literal braces in the template; `{context}` is a variable.
    system_text = """You are an expert in finance, economics, investing, ethics, derivatives and markets.
Use the following pieces of context to answer the users question. If you don't know the answer,
just say that you don't know, don't try to make up an answer. Provide a source reference.
ALWAYS return a "sources" part in your answer.
The "sources" part should be a reference to the source of the documents from which you got your answer. List all sources used
The output should be a markdown code snippet formatted in the following schema:
```json
{{
answer: is foo
sources: xyz
}}
```
Begin!
----------------
{context}"""
    system_part = SystemMessagePromptTemplate.from_template(system_text)
    human_part = HumanMessagePromptTemplate.from_template("{question}")
    return ChatPromptTemplate.from_messages([system_part, human_part])
def load_chain():
    """Assemble the question-answering chain over the CFA Level 1 index.

    Combines a ``ChatOpenAI`` model at ``temperature=0``, the FAISS store from
    ``load_vectorstore`` (built with the module-level ``embeddings``), and the
    prompt from ``load_prompt``. Retrieved source documents are returned
    alongside the answer.

    Returns:
        The chain built by ``ChatVectorDBChain.from_llm``.
    """
    # Same evaluation order as before: model, then vector store, then prompt.
    chat_model = ChatOpenAI(temperature=0)
    store = load_vectorstore(embeddings)
    prompt = load_prompt()
    return ChatVectorDBChain.from_llm(chat_model, store,
                                      qa_prompt=prompt,
                                      return_source_documents=True)
def get_answer(question, chain=None):
    """Answer ``question`` with the retrieval chain and gather its sources.

    Args:
        question: The user's question.
        chain: Optional prebuilt chain. It must be callable with a
            ``{"question": ..., "chat_history": [...]}`` dict and return a
            mapping with ``"answer"`` and ``"source_documents"``. Defaults to
            a fresh ``load_chain()``.

    Returns:
        A tuple ``(answer, pages, extract)``:
        ``answer`` is the model's reply. ``pages`` is a comma-separated string
        of the distinct source pages in retrieval order (e.g. ``"12, 7"``).
        ``extract`` is markdown with one ``- **Page: N**`` entry per source
        document, followed by its text.
    """
    if chain is None:
        chain = load_chain()
    # ChatVectorDBChain expects "question" and "chat_history" inputs and
    # returns its reply under "answer"; "query"/"result" are RetrievalQA's
    # keys and do not match this chain.
    result = chain({"question": question, "chat_history": []})
    answer = result["answer"]
    docs = result["source_documents"]

    # dict.fromkeys de-duplicates while keeping first-seen (retrieval) order,
    # unlike a set whose iteration order is arbitrary.
    unique_pages = dict.fromkeys(doc.metadata['page'] for doc in docs)
    pages = ", ".join(str(page) for page in unique_pages)

    # - **Page: {number}**
    # {extracted text from book}
    extract = "".join(
        f"- **Page: {doc.metadata['page']}**\n{doc.page_content}\n\n" for doc in docs
    )
    return answer, pages, extract
##### sidebar section 2 ####
# `api_key` is not passed anywhere explicitly; ChatOpenAI(temperature=0) in
# load_chain presumably reads it from the environment. This check fails early
# with a readable message instead of a KeyError traceback.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    st.error("The OPENAI_API_KEY environment variable is not set, so questions cannot be answered.")
    st.stop()

##### main ####
user_input = st.text_input("Your question", "What is an MBS and who are the main issuer and investors of the MBS market?", key="input")
col1, col2 = st.columns([10, 1])
# show question
col1.write(f"**You:** {user_input}")
# ask button to the right of the displayed question
ask = col2.button("Ask", type="primary")

if ask:
    if not user_input.strip():
        st.warning("Please enter a question first.")
        st.stop()
    # NOTE(review): the spinner emoji were mis-encoded in the source;
    # reconstructed as 🥺👉🏻👈🏻 — confirm.
    with st.spinner("this can take about a minute for your first question because some models have to be downloaded 🥺👉🏻👈🏻"):
        try:
            answer, pages, extract = get_answer(question=user_input)
        except Exception as e:  # top-level UI boundary: surface any failure to the user
            st.error(f"Error while answering your question: {e}")
            st.stop()
    st.write(answer)
    # sources
    with st.expander(label=f"From pages: {pages}", expanded=False):
        st.markdown(extract)