Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
""" | |
To run:
- activate the virtual environment
- streamlit run path/to/streamlit_app.py
""" | |
import logging | |
import os | |
import re | |
import sys | |
import time | |
import warnings | |
import shutil | |
from langchain.chat_models import ChatOpenAI | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
import openai | |
import pandas as pd | |
import streamlit as st | |
from st_aggrid import GridOptionsBuilder, AgGrid, GridUpdateMode, ColumnsAutoSizeMode | |
from streamlit_chat import message | |
from streamlit_langchain_chat.constants import * | |
from streamlit_langchain_chat.customized_langchain.llms import OpenAI, AzureOpenAI, AzureOpenAIChat | |
from streamlit_langchain_chat.dataset import Dataset | |
# Configure logging: one root handler writing formatted records to stdout.
# basicConfig(stream=sys.stdout) replaces the former basicConfig() (which
# installs a stderr handler) followed by an extra bare stdout handler; that
# combination emitted every record twice.
# force=True removes root handlers left over from a previous execution of
# this module, so re-running the script does not accumulate handlers.
logging.basicConfig(
    format="\n%(asctime)s\n%(message)s",
    level=logging.INFO,
    stream=sys.stdout,
    force=True,
)
# Silence all Python warnings process-wide.
warnings.filterwarnings('ignore')
# Seed per-session state the first time the script runs for a session.
# Keys that already exist keep their values. Every key gets a freshly built
# default, so no two keys share one list object.
for _state_key, _make_default in (
    ('generated', list),    # answers shown in the chat
    ('past', list),         # questions asked by the user
    ('costs', list),        # cost string of each answer
    ('contexts', list),     # context of each answer
    ('chunks', list),       # retrieved chunks of each answer
    ('user_input', str),    # text of the chat input box ("")
    ('dataset', lambda: None),  # indexed Dataset, built on demand
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
def check_api_keys() -> bool:
    """Return True when the API keys needed to build an index are available.

    ``OPENAI_API_KEY`` must always be non-empty. ``PINECONE_API_KEY`` is only
    required when the Pinecone index store is selected
    (``app.params['index_id'] == 2``). Both values are read from the process
    environment, which is where the sidebar inputs store them.
    """
    index_id = app.params['index_id']
    # os.getenv with a '' default always returns a str, so truthiness is
    # exactly "non-empty".
    openai_api_key_ready = bool(os.getenv('OPENAI_API_KEY', ''))
    # Only the Pinecone store (index_id 2) needs its own key.
    pinecone_api_key_ready = bool(os.getenv('PINECONE_API_KEY', '')) if index_id == 2 else True
    return openai_api_key_ready and pinecone_api_key_ready
def check_combination_point() -> bool:
    """Validate sidebar step 2: the model provider and its credentials.

    Azure OpenAI (type_id 1) needs a non-empty ``OPENAI_API_KEY``, an API base
    and a deployment id. OpenAI (type_id 2) needs the key and an API base. Any
    other type is reported as not ready.
    """
    type_id = app.params['type_id']
    key_present = bool(os.getenv('OPENAI_API_KEY', ''))
    endpoint = app.params['api_base']

    if type_id not in (1, 2):
        return False

    required = [key_present, endpoint]
    if type_id == 1:
        # Azure additionally needs the name of the deployment to call.
        required.append(app.params['deployment_id'])
    return all(required)
def check_index() -> bool:
    """Validate step 6: an index has been built, or no source is needed.

    The dataset in session state counts as indexed when its
    ``index_docstore`` attribute exists and is truthy. The "no source" option
    (``source_id == 4``) needs no index at all.
    """
    current = st.session_state['dataset']
    has_index = getattr(current, "index_docstore", False)
    no_source_needed = app.params['source_id'] == 4
    return bool(has_index or no_source_needed)
def check_index_point() -> bool:
    """Validate sidebar step 3: the index store and its credentials.

    An index store must be selected. The Pinecone store (index_id 2) also
    requires non-empty ``PINECONE_API_KEY`` and ``PINECONE_ENVIRONMENT``
    environment variables.
    """
    index_id = app.params['index_id']
    if index_id == 2:
        key_ok = bool(os.getenv('PINECONE_API_KEY', ''))
        environment = os.getenv('PINECONE_ENVIRONMENT', False)
    else:
        key_ok = True
        environment = True
    return bool(index_id and key_ok and environment)
def check_params_point() -> bool:
    """Validate sidebar step 4: a truthy top-k and a float temperature."""
    params = app.params
    top_k, temperature = params['max_sources'], params['temperature']
    return bool(top_k) and isinstance(temperature, float)
def check_source_point() -> bool:
    """Validate sidebar step 1 (source selection).

    No validation is performed yet: step 1 is always considered ready.
    """
    ready = True
    return ready
def clear_chat_history():
    """Forget the conversation: questions, answers, contexts, chunks and costs.

    This is the on_click callback of the "clear conversation" button. The
    lists are only replaced by new empty ones when at least one holds data.
    """
    history_keys = ('past', 'generated', 'contexts', 'chunks', 'costs')
    if not any(st.session_state[k] for k in history_keys):
        return
    for k in history_keys:
        st.session_state[k] = []
def clear_index():
    """Delete the on-disk index and forget the current dataset.

    With a dataset in session state, its ``index_path`` directory is removed
    (when it exists) and the session dataset is reset to None. Without one,
    the ``TEMP_DIR / "default"`` directory is removed if it exists.
    """
    dataset = st.session_state['dataset']
    if not dataset:
        default_dir = TEMP_DIR / "default"
        if default_dir.exists():
            shutil.rmtree(str(default_dir))
        return

    # Remove the index directory together with its files.
    index_dir = dataset.index_path
    if index_dir.exists():
        shutil.rmtree(str(index_dir))
    st.session_state['dataset'] = None
def check_sources() -> bool:
    """Validate step 5: exactly one kind of source is provided, or none is needed.

    Ready when only uploaded files are selected, only urls are filled in, or
    the "no source" option (``source_id == 4``) is active. Files and urls
    together do not count as ready.
    """
    rows = app.params['uploaded_files_rows']
    urls_frame = app.params['urls_df']
    source_id = app.params['source_id']

    # A selected ["", ""] placeholder row (shown when nothing is uploaded)
    # does not count as a file.
    has_files = bool(rows) and rows[-1].get('filepath') != ""
    has_urls = any(url for url, _citation in urls_frame.to_numpy())
    # "Exactly one of files / urls" is a boolean XOR.
    return has_files != has_urls or source_id == 4
def collect_dataset_and_built_index():
    """Build a ``Dataset`` from the sources chosen in the sidebar and index it.

    Steps:
    - Read the sidebar choices from ``app.params`` and point the ``openai``
      module at Azure (type_id 1) or OpenAI.
    - Create the LLM wrapper.
    - Add every non-empty url row and every selected uploaded file.
    - Build a FAISS (index_id 1) or Pinecone index over
      ``text-embedding-ada-002`` embeddings.
    - Store the dataset in ``st.session_state['dataset']``.

    Urls are added best-effort: a url that fails is logged and skipped.
    Errors while adding uploaded files or building the index propagate to
    the caller.
    """
    start = time.time()
    params = app.params
    uploaded_files_rows = params['uploaded_files_rows']
    urls_df = params['urls_df']
    type_id = params['type_id']
    temperature = params['temperature']
    index_id = params['index_id']
    api_base = params['api_base']
    deployment_id = params['deployment_id']

    some_files = bool(uploaded_files_rows) and uploaded_files_rows[-1].get('filepath') != ""
    some_urls = any(url for url, _citation in urls_df.to_numpy())

    # The openai SDK is configured through module-level globals.
    openai.api_type = "azure" if type_id == 1 else "open_ai"
    openai.api_base = api_base
    openai.api_version = "2023-03-15-preview" if type_id == 1 else None

    # "text-davinci-003" goes through the (non-chat) OpenAI wrapper and every
    # other deployment through ChatOpenAI. Both use the deployment entered in
    # the sidebar.
    llm_class = OpenAI if deployment_id == "text-davinci-003" else ChatOpenAI
    dataset = Dataset(
        llm=llm_class(
            temperature=temperature,
            max_tokens=512,
            deployment_id=deployment_id,
        )
    )

    # Url documents.
    if some_urls:
        for _, url_row in urls_df.iterrows():
            url = url_row.get('urls', '')
            if not url:
                continue
            citation = url_row.get('citation string', '')
            try:
                dataset.add(
                    url,
                    citation,
                    citation,  # NOTE(review): presumably the document key; confirm against Dataset.add
                    disable_check=True,  # True to accept Japanese letters
                )
            except Exception as err:  # best effort: keep loading the remaining urls
                logging.warning("Skipping url %s: %s", url, err)

    # Uploaded local files.
    if some_files:
        for row in uploaded_files_rows:
            citation = row.get('citation string')
            # A citation containing a comma is not used as the key (key=None).
            key = citation if citation is not None and ',' not in citation else None
            dataset.add(
                row.get('filepath'),
                citation,
                key=key,
                disable_check=True,  # True to accept Japanese letters
            )

    openai_embeddings = OpenAIEmbeddings(
        document_model_name="text-embedding-ada-002",
        query_model_name="text-embedding-ada-002",
    )
    if index_id == 1:
        dataset._build_faiss_index(openai_embeddings)
    else:
        dataset._build_pinecone_index(openai_embeddings)

    st.session_state['dataset'] = dataset
    if OPERATING_MODE == "debug":
        print(f"time to collect dataset: {time.time() - start:.2f} [s]")
def configure_streamlit_and_page():
    """Apply the page configuration and inject the responsive-columns CSS.

    ``main`` calls this before rendering anything else.
    """
    st.set_page_config(**ST_CONFIG)

    # Keep the responsive column layout on mobile too: columns take roughly
    # half of the row width (min-width calc(50% - 1rem)).
    column_css = (
        "<style>\n"
        "[data-testid=\"column\"] {\n"
        "width: calc(50% - 1rem);\n"
        "flex: 1 1 calc(50% - 1rem);\n"
        "min-width: calc(50% - 1rem);\n"
        "}\n"
        "</style>"
    )
    st.write(column_css, unsafe_allow_html=True)
def get_answer():
    """Answer the current chat input using the indexed dataset.

    Returns the result of ``dataset.query``. ``on_enter`` reads its
    ``answer``, ``context``, ``chunks`` and ``cost_str``. Returns None when
    the input, the dataset, the model type or the index type is missing.
    """
    query = st.session_state['user_input']
    dataset = st.session_state['dataset']
    type_id = app.params['type_id']
    index_id = app.params['index_id']
    # NOTE(review): the "Top-k" sidebar value is read here but never passed to
    # dataset.query — confirm whether Dataset.query accepts max_sources.
    max_sources = app.params['max_sources']  # noqa: F841

    if not (query and dataset and type_id and index_id):
        return None

    # Earlier turns as (question, answer) pairs.
    chat_history = list(zip(st.session_state['past'], st.session_state['generated']))
    # Marginal relevance only with FAISS (index_id 1); it must be False when
    # Pinecone is used.
    use_marginal_relevance = index_id == 1

    started = time.time()
    embeddings = OpenAIEmbeddings(
        document_model_name="text-embedding-ada-002",
        query_model_name="text-embedding-ada-002",
    )
    result = dataset.query(
        query,
        embeddings,
        chat_history,
        marginal_relevance=use_marginal_relevance,
    )
    if OPERATING_MODE == "debug":
        print(f"time to get answer: {time.time() - started:.2f} [s]")
        print("-" * 10)
    return result
def load_main_page():
    """
    Load the body of web app.

    Renders, in order:
    - the title and the current status;
    - the "clear index" / "clear conversation" buttons;
    - every past question/answer pair, each with expanders for its context,
      chunks and costs;
    - the chat input box.

    Hidden anchors at the top and bottom let the user jump between both ends.
    """
    st.markdown(f"## Retrieval Augmented Document Q&A Chat ({APP_VERSION})")
    validate_status()  # fills app.params['status'] for the next line
    st.markdown(f"#### **Status**: {app.params['status']}")
    st.markdown("<div id='linkto_top'></div>", unsafe_allow_html=True)  # hidden anchor

    index_col, chat_col, link_col = st.columns(3)
    index_col.button(label="clear index", type="primary", on_click=clear_index)
    chat_col.button(label="clear conversation", type="primary", on_click=clear_chat_history)
    link_col.markdown("<a href='#linkto_bottom'>Link to bottom</a>", unsafe_allow_html=True)

    state = st.session_state
    details = (("See context", 'contexts'), ("See chunks", 'chunks'), ("See costs", 'costs'))
    for turn, answer in enumerate(state["generated"]):
        message(state['past'][turn], is_user=True, key=f"{turn}_user")
        message(answer, key=str(turn))
        for title, state_key in details:
            with st.expander(title):
                st.write(state[state_key][turn])

    # The input is usable once an index is built or no source is required.
    st.text_input("You:",
                  key='user_input',
                  on_change=on_enter,
                  disabled=not check_index())
    st.markdown("<a href='#linkto_top'>Link to top</a>", unsafe_allow_html=True)
    st.markdown("<div id='linkto_bottom'></div>", unsafe_allow_html=True)  # hidden anchor
def load_sidebar_page():
    """Render the sidebar wizard (steps 1-6) and record the choices in ``app.params``.

    Every step helper writes its keys into ``app.params``; later checks and
    the index build read them from there. The steps render in the same order
    as before, so widget order on the page is unchanged.
    """
    st.sidebar.markdown("## Instructions")
    _sidebar_select_source()
    _sidebar_select_model()
    _sidebar_select_index_store()
    _sidebar_configure_parameters()
    _sidebar_upload_local_files()
    _sidebar_enter_urls()
    _sidebar_build_index()


def _sidebar_select_source():
    """Step 1: choose where the context documents come from (sets ``source_id``)."""
    st.sidebar.markdown("1. Select a source:")
    source_selected = st.sidebar.selectbox(
        "Choose the location of your info to give context to chatgpt",
        list(SOURCES_IDS))
    app.params['source_id'] = SOURCES_IDS.get(source_selected, None)


def _sidebar_select_model():
    """Step 2: choose Azure OpenAI or OpenAI and collect the matching credentials.

    Sets ``type_id``, ``api_base`` and ``deployment_id`` in ``app.params``
    (always, so later checks never hit a missing key). Stores the API key in
    the ``OPENAI_API_KEY`` environment variable.
    """
    st.sidebar.markdown("2. Select a model (LLM):")
    combination_selected = st.sidebar.selectbox(
        "Choose type: MSF Azure OpenAI and model / OpenAI",
        list(TYPE_IDS))
    app.params['type_id'] = TYPE_IDS.get(combination_selected, None)
    # NOTE(review): os.environ is process-wide. If several browser sessions
    # are served by the same process, one user's key is visible to the
    # others. Consider keeping keys in st.session_state and passing them
    # explicitly.
    if app.params['type_id'] == 1:  # with AzureOpenAI endpoint
        # https://docs.streamlit.io/library/api-reference/widgets/st.text_input
        os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
            label="Enter Azure OpenAI API Key",
            type="password"
        ).strip()
        app.params['api_base'] = st.sidebar.text_input(
            label="Enter Azure API base",
            placeholder="https://<api_base_endpoint>.openai.azure.com/",
        ).strip()
        app.params['deployment_id'] = st.sidebar.text_input(
            label="Enter Azure deployment_id",
        ).strip()
    elif app.params['type_id'] == 2:  # with OpenAI endpoint
        os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
            label="Enter OpenAI API Key",
            placeholder="sk-...",
            type="password"
        ).strip()
        app.params['api_base'] = "https://api.openai.com/v1"
        app.params['deployment_id'] = None
    else:
        # Unknown type: define the keys anyway so the status check reports
        # "not ready" instead of raising KeyError.
        app.params['api_base'] = None
        app.params['deployment_id'] = None


def _sidebar_select_index_store():
    """Step 3: choose the index store (sets ``index_id``); Pinecone also asks for credentials."""
    st.sidebar.markdown("3. Select a index store:")
    index_selected = st.sidebar.selectbox(
        "Type of Index",
        list(INDEX_IDS))
    app.params['index_id'] = INDEX_IDS.get(index_selected, None)
    if app.params['index_id'] == 2:  # with pinecone
        os.environ['PINECONE_API_KEY'] = st.sidebar.text_input(
            label="Enter pinecone API Key",
            type="password"
        ).strip()
        os.environ['PINECONE_ENVIRONMENT'] = st.sidebar.text_input(
            label="Enter pinecone environment",
            placeholder="eu-west1-gcp",
        ).strip()


def _sidebar_configure_parameters():
    """Step 4: retrieval top-k (``max_sources``) and LLM ``temperature``."""
    st.sidebar.markdown("4. Choose configuration:")
    # https://docs.streamlit.io/library/api-reference/widgets/st.number_input
    # min/max enforce the 1-5 range that the label promises.
    app.params['max_sources'] = st.sidebar.number_input(
        label="Top-k: Number of chunks/sections (1-5)",
        step=1,
        format="%d",
        value=5,
        min_value=1,
        max_value=5,
    )
    temperature = st.sidebar.number_input(
        label="Temperature (0.0 – 1.0)",
        step=0.1,
        format="%f",
        value=0.0,
        min_value=0.0,
        max_value=1.0
    )
    # Keep one decimal, matching the 0.1 step.
    app.params['temperature'] = round(temperature, 1)


def _sidebar_upload_local_files():
    """Step 5 for local files (source_id 1): upload files and select/edit their citations.

    Always resets ``uploaded_files_rows``. For source_id 1 it becomes the
    grid's selected rows, which are later read via ``.get('filepath')`` and
    ``.get('citation string')``.
    """
    app.params['uploaded_files_rows'] = []
    if app.params['source_id'] != 1:
        return
    # https://docs.streamlit.io/library/api-reference/widgets/st.file_uploader
    # https://towardsdatascience.com/make-dataframes-interactive-in-streamlit-c3d0c4f84ccb
    st.sidebar.markdown("""5. Upload your local documents and modify citation strings (optional)""")
    uploaded_files = st.sidebar.file_uploader(
        "Choose files",
        accept_multiple_files=True,
        type=['pdf', 'PDF',
              'txt', 'TXT',
              'html',
              'docx', 'DOCX',
              'pptx', 'PPTX',
              ],
    )
    uploaded_files_df = pd.DataFrame(
        request_pathname(uploaded_files),
        columns=['filepath', 'citation string'])
    builder = GridOptionsBuilder.from_dataframe(uploaded_files_df)
    # Pre-select every real file row; the ["", ""] placeholder stays unselected.
    has_real_files = uploaded_files_df.iloc[-1, 0] != ""
    builder.configure_selection(
        selection_mode='multiple',
        pre_selected_rows=list(range(uploaded_files_df.shape[0])) if has_real_files else [],
        use_checkbox=True,
    )
    builder.configure_column("citation string", editable=True)
    builder.configure_auto_height()
    grid_options = builder.build()
    with st.sidebar:
        uploaded_files_ag_grid = AgGrid(
            uploaded_files_df,
            gridOptions=grid_options,
            update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
        )
    app.params['uploaded_files_rows'] = uploaded_files_ag_grid["selected_rows"]


def _sidebar_enter_urls():
    """Step 5 for urls (source_id 3): edit up to five url/citation rows.

    Always resets ``urls_df``. For source_id 3 it becomes the grid data
    without the rows whose url is empty.
    """
    app.params['urls_df'] = pd.DataFrame()
    if app.params['source_id'] != 3:
        return
    st.sidebar.markdown("""5. Write some urls and modify citation strings if you want (to look prettier)""")
    # An editable AgGrid works on streamlit 1.19.0. On 1.20.0+,
    # st.experimental_data_editor(..., num_rows="dynamic") is an alternative.
    urls_df = pd.DataFrame(
        [["", ""] for _ in range(5)],
        columns=['urls', 'citation string'])
    builder = GridOptionsBuilder.from_dataframe(urls_df)
    builder.configure_columns(['urls', 'citation string'], editable=True)
    builder.configure_auto_height()
    grid_options = builder.build()
    with st.sidebar:
        urls_ag_grid = AgGrid(
            urls_df,
            gridOptions=grid_options,
            update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
        )
    edited = urls_ag_grid.data
    # Keep only the rows where a url was typed.
    app.params['urls_df'] = edited[edited.urls != ""]


def _sidebar_build_index():
    """Step 6 (sources 1-3): a button that builds the index once keys and sources are ready."""
    if app.params['source_id'] not in (1, 2, 3):
        return
    st.sidebar.markdown("""6. Build an index where you can ask""")
    api_keys_ready = check_api_keys()
    source_ready = check_sources()
    if st.sidebar.button("Build index", disabled=not (api_keys_ready and source_ready)):
        collect_dataset_and_built_index()
def main():
    """Render one run of the app: page setup, then the sidebar, then the main body.

    The sidebar runs before the main body because it fills ``app.params``.
    The main body reads those values when computing the status.
    """
    for render in (configure_streamlit_and_page, load_sidebar_page, load_main_page):
        render()
def on_enter():
    """Chat-input callback: query the dataset and record the new turn.

    When ``get_answer`` returns a result, the question and the result's
    answer, context, chunks and cost string are appended to the matching
    session-state lists. The input box is then cleared.
    """
    reply = get_answer()
    if reply:
        state = st.session_state
        state['past'].append(state['user_input'])
        state['generated'].append(reply.answer)
        state['contexts'].append(reply.context)
        state['chunks'].append(reply.chunks)
        state['costs'].append(reply.cost_str)
    st.session_state['user_input'] = ""
def request_pathname(files):
    """Save uploaded files under ``TEMP_DIR`` and describe them for the files grid.

    Args:
        files: what ``st.file_uploader(accept_multiple_files=True)`` returned.
            These are objects exposing ``name`` and ``getbuffer()``; the
            value may be empty or None.

    Returns:
        One ``[filepath, citation]`` pair per file. ``filepath`` is the saved
        file's path relative to ``ROOT_DIR``; ``citation`` defaults to the
        uploaded file name. Without files, a single ``["", ""]`` placeholder
        row is returned so the grid still has one row.
    """
    if not files:
        return [["", ""]]

    # exist_ok makes a separate existence check unnecessary.
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    rows = []
    for file in files:
        # The name is supplied by the client. Keep only its final component
        # so a crafted name such as "../../x" cannot escape TEMP_DIR.
        target = TEMP_DIR / os.path.basename(file.name)
        # Write to the absolute location that was just created; a relative
        # path would be resolved against the current working directory instead.
        with open(target, "wb") as f:
            f.write(file.getbuffer())
        # NOTE(review): consumers open this relative path, which assumes the
        # app runs with ROOT_DIR as its working directory.
        rows.append([str(target.relative_to(ROOT_DIR)), file.name])
    return rows
def validate_status():
    """Set ``app.params['status']`` to the first unready sidebar step, or to ready.

    All six checks run up front, in step order. The message of the first
    failing one is shown; when none fails, the status is "ready".
    """
    steps = [
        (check_source_point(), "⚠️Review step 1 on the sidebar."),
        (check_combination_point(), "⚠️Review step 2 on the sidebar. API Keys or endpoint, ..."),
        (check_index_point(), "⚠️Review step 3 on the sidebar. Index API Key or environment."),
        (check_params_point(), "⚠️Review step 4 on the sidebar"),
        (check_sources(), "⚠️Review step 5 on the sidebar. Waiting for some source..."),
        (check_index(), "⚠️Review step 6 on the sidebar. Waiting for press button to create index ..."),
    ]
    app.params['status'] = next(
        (warning for ready, warning in steps if not ready),
        "✨Ready✨",
    )
class StreamlitLangchainChatApp:
    """Holder for the sidebar choices of one run of the chat app.

    ``params`` is filled by ``load_sidebar_page`` and ``validate_status``.
    The check, build and query helpers read it through the module-level
    ``app`` instance.
    """

    def __init__(self) -> None:
        """Start with an empty parameter dict; the constructor takes no arguments."""
        self.params: dict = {}

    def run(self, **state) -> None:
        """Render the whole app once; ``state`` is accepted but unused."""
        main()
if __name__ == "__main__":
    # Every helper in this module reads the global `app` created here. They
    # therefore only work when the script runs as __main__, as it does under
    # `streamlit run` per the module docstring.
    app = StreamlitLangchainChatApp()
    app.run()