import os
import gc
import tempfile
import uuid

import pandas as pd
import streamlit as st

from llama_index.core import Settings
from llama_index.llms.cerebras import Cerebras
from llama_index.core import PromptTemplate
from llama_index.embeddings.mixedbreadai import MixedbreadAIEmbedding, EncodingFormat
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.readers.docling import DoclingReader
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.core.llms import ChatMessage
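
# Per-session state: a unique id plus a cache of query engines, persisted across Streamlit reruns.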
if "id" not in st.session_state: |
|
st.session_state.id = uuid.uuid4() |
|
st.session_state.file_cache = {} |
|
|
|
session_id = st.session_state.id |
|
client = None |
|
|
|
|
|
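
# Cerebras-hosted Llama 3.3 70B is the chat LLM; it reads CEREBRAS_API_KEY from the environment.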
def load_llm():
    api_key = os.getenv("CEREBRAS_API_KEY")
    llm = Cerebras(model="llama-3.3-70b", api_key=api_key)
    return llm
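
# Clear the chat history and context, then run garbage collection.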
def reset_chat():
    st.session_state.messages = []
    st.session_state.context = None
    gc.collect()
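
# Show a preview of the uploaded Excel file as a dataframe.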
def display_excel(file):
    st.markdown("### Excel Preview")
    df = pd.read_excel(file)
    st.dataframe(df)
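
# Sidebar: upload an Excel file, parse it with Docling, and build a streaming query engine over it.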
with st.sidebar:
    st.header("Add your documents!")

    uploaded_file = st.file_uploader("Choose your `.xlsx` file", type=["xlsx", "xls"])

    if uploaded_file:
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                file_path = os.path.join(temp_dir, uploaded_file.name)

                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getvalue())

                file_key = f"{session_id}-{uploaded_file.name}"
                st.write("Indexing your document...")
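
                # Index each uploaded file only once per session; reuse the cached engine afterwards.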
                if file_key not in st.session_state.get('file_cache', {}):
                    if os.path.exists(temp_dir):
                        reader = DoclingReader()
                        loader = SimpleDirectoryReader(
                            input_dir=temp_dir,
                            file_extractor={".xlsx": reader},
                        )
                    else:
                        st.error('Could not find the file you uploaded, please check again...')
                        st.stop()

                    docs = loader.load_data()
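
                    # Set up the chat LLM and the MixedbreadAI embedding model (requires MXBAI_API_KEY).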
                    llm = load_llm()

                    mixedbread_api_key = os.getenv("MXBAI_API_KEY")
                    embed_model = MixedbreadAIEmbedding(
                        api_key=mixedbread_api_key,
                        model_name="mixedbread-ai/mxbai-embed-large-v1",
                    )

                    # Chunk the Markdown produced by Docling and build the vector index.
                    Settings.embed_model = embed_model
                    node_parser = MarkdownNodeParser()
                    index = VectorStoreIndex.from_documents(
                        documents=docs,
                        transformations=[node_parser],
                        show_progress=True,
                    )

                    Settings.llm = llm
                    query_engine = index.as_query_engine(streaming=True)
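
                    # Replace the default QA prompt so the model reasons step by step and admits when it doesn't know.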
                    qa_prompt_tmpl_str = (
                        "Context information is below.\n"
                        "---------------------\n"
                        "{context_str}\n"
                        "---------------------\n"
                        "Given the context information above I want you to think step by step to answer the query in a highly precise and crisp manner focused on the final answer, in case you don't know the answer say 'I don't know!'.\n"
                        "Query: {query_str}\n"
                        "Answer: "
                    )
                    qa_prompt_tmpl = PromptTemplate(qa_prompt_tmpl_str)

                    query_engine.update_prompts(
                        {"response_synthesizer:text_qa_template": qa_prompt_tmpl}
                    )

                    st.session_state.file_cache[file_key] = query_engine
                else:
                    query_engine = st.session_state.file_cache[file_key]

                st.success("Ready to Chat!")
                display_excel(uploaded_file)

        except Exception as e:
            st.error(f"An error occurred: {e}")
            st.stop()
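
# Main area: app title, a button to clear the conversation, and the chat interface.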
col1, col2 = st.columns([6, 1])

with col1:
    st.header("RAG over Excel using Docling 🐥 & Llama-3.3-70B")

with col2:
    st.button("Clear ↺", on_click=reset_chat)
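
# Start with an empty conversation and replay stored messages on every rerun.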
if "messages" not in st.session_state: |
|
reset_chat() |
|
|
|
|
|
for message in st.session_state.messages: |
|
with st.chat_message(message["role"]): |
|
st.markdown(message["content"]) |
|
|
|
|
|
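
# Handle a new user message: the query engine built in the sidebar answers it with a streamed response.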
if prompt := st.chat_input("What's up?"):
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""

        # Stream the answer chunk by chunk, showing a cursor while tokens arrive.
        streaming_response = query_engine.query(prompt)

        for chunk in streaming_response.response_gen:
            full_response += chunk
            message_placeholder.markdown(full_response + "▌")

        message_placeholder.markdown(full_response)

    st.session_state.messages.append({"role": "assistant", "content": full_response})