# Streamlit search demo.
#
# Flow: embed the query through the Mixedbread API and retrieve candidates,
# apply the user-selected scalar/discrete filters, then rerank the surviving
# candidates (Mixedbread reranker by default, Cohere optionally) and display
# the top hits.
#
# NOTE: module-level documentation is kept in `#` comments on purpose so that
# no bare string expression appears at the top level of a Streamlit script.

from copy import deepcopy

# NOTE(review): torch, SentenceTransformer and CrossEncoder are not referenced
# in this script; they are kept because removing file-level imports may change
# import-time behaviour elsewhere - TODO confirm and drop if truly unused.
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
import streamlit as st
import numpy as np
import cohere
from mixedbread_ai.client import MixedbreadAI

import src.constants.config as configurations
from src.constants.credentials import cohere_trial_key, mixedbread_key
from src.reader import Reader
from src.utils_search import UtilsSearch
from src.pytorch_modules.datasets.schema_string_dataset import SchemaStringDataset

# Number of candidates requested from the first-stage retrieval.
RETRIEVE_TOP_K = 1000
# Number of filtered candidates handed to the reranker.
RERANK_CANDIDATES = 100
# Number of reranked hits shown to the user.
RERANK_TOP_K = 10
# Each document passed to the Mixedbread reranker is sliced to this length.
MAX_DOCUMENT_LENGTH = 256
COHERE_RERANK_MODEL = 'rerank-english-v3.0'

# From here on `configurations` is the selected config dict, no longer the
# config module it was imported as.
configurations = configurations.service_mxbai_msc_direct_config
api_key = cohere_trial_key
co = cohere.Client(api_key)
semantic_column_names = configurations["semantic_column_names"]
model = MixedbreadAI(api_key=mixedbread_key)
cross_encoder_name = configurations["cross_encoder_name"]


# NOTE(review): the returned tuple contains an index and a UtilsSearch
# instance, not only data; st.cache_resource may be the better fit - confirm
# against the Streamlit caching documentation before changing.
@st.cache_data
def init():
    """Load the dataframe and build the search index.

    Returns:
        A ``(df, index, search_utils)`` tuple: the dataframe produced by
        ``Reader.read()``, the index built from it by
        ``UtilsSearch.dataframe_to_index`` and the ``UtilsSearch`` instance.
    """
    config = configurations
    search_utils = UtilsSearch(config)
    reader = Reader(config=config["reader_config"])
    df = reader.read()
    index = search_utils.dataframe_to_index(df)
    return df, index, search_utils


def get_possible_values_for_column(column_name, search_utils, df):
    """Return the selectable values for ``column_name``, cached per session.

    The values are computed once with ``search_utils.top_10_common_values``
    and stored in ``st.session_state`` under the column name itself.

    Args:
        column_name: Name of the discrete column.
        search_utils: Object providing ``top_10_common_values(df, column_name)``.
        df: Dataframe the values are computed from.

    Returns:
        Whatever ``top_10_common_values`` returned for this column.
    """
    # Item access (rather than setattr/getattr) guarantees the lookup goes
    # through the session-state mapping and can never resolve to an attribute
    # or method of the session-state object itself.
    if column_name not in st.session_state:
        st.session_state[column_name] = search_utils.top_10_common_values(df, column_name)
    return st.session_state[column_name]


def _collect_scalar_filters(scalar_columns):
    """Render a min/max number input per scalar column; return the chosen ranges."""
    chosen = []
    for column in scalar_columns:
        col_name = column["column_name"]
        min_val = float(column["min_value"])
        max_val = float(column["max_value"])
        user_min = st.number_input(
            f'Minimum {col_name.capitalize()}',
            min_value=min_val,
            max_value=max_val,
            value=min_val,
        )
        user_max = st.number_input(
            f'Maximum {col_name.capitalize()}',
            min_value=min_val,
            max_value=max_val,
            value=max_val,
        )
        chosen.append({"column_name": col_name, "min_value": user_min, "max_value": user_max})
    return chosen


def _collect_discrete_filters(discrete_columns, search_utils, df):
    """Render a multiselect per discrete column; return the selected values."""
    chosen = []
    for column in discrete_columns:
        col_name = column["column_name"]
        possible_values = get_possible_values_for_column(col_name, search_utils, df)
        # NOTE(review): the configured defaults are passed through unchanged;
        # presumably they must be among `possible_values` for the widget to
        # accept them - verify against the Streamlit multiselect docs.
        selected_values = st.multiselect(
            f'Select {col_name.capitalize()}',
            options=possible_values,
            default=column["default_values"],
        )
        chosen.append({"column_name": col_name, "default_values": selected_values})
    return chosen


def _rerank_with_mixedbread(query, candidates):
    """Rerank ``candidates`` with the Mixedbread reranker; return hit positions."""
    records = candidates.to_dict(orient='records')
    dataset_str = SchemaStringDataset(records, configurations)
    # Keep only the first MAX_DOCUMENT_LENGTH elements of every "inputs" entry
    # (characters, assuming "inputs" is a string - TODO confirm).
    documents = [item["inputs"][:MAX_DOCUMENT_LENGTH] for item in dataset_str]
    res = model.reranking(
        model=cross_encoder_name,
        query=query,
        input=documents,
        top_k=RERANK_TOP_K,
        return_input=False,
    )
    return [item.index for item in res.data]


def _rerank_with_cohere(query, candidates):
    """Rerank ``candidates`` with Cohere on the semantic columns; return hit positions.

    Missing values in ``candidates`` are replaced by empty strings in place.
    """
    candidates.fillna(value="", inplace=True)
    docs = [
        {name: str(doc[name]) for name in semantic_column_names}
        for doc in candidates.to_dict('records')
    ]
    rank_fields = list(docs[0].keys())
    results = co.rerank(
        query=query,
        documents=docs,
        top_n=RERANK_TOP_K,
        model=COHERE_RERANK_MODEL,
        rank_fields=rank_fields,
    )
    return [hit.index for hit in results.results]


# Initialize once per session, then reuse the objects kept in session state.
if 'init_results' not in st.session_state:
    st.session_state.init_results = init()

df, index, search_utils = st.session_state.init_results

# Streamlit app layout
st.title('Search Demo')

# Input fields
query = st.text_input('Enter your search query here')
# The Cohere toggle is currently disabled; to expose it again use:
# use_cohere = st.checkbox('Use Cohere', value=False)
use_cohere = False

# Work on a copy so the shared configuration is never mutated by user input.
programmatic_search_config = deepcopy(configurations['programmatic_search_config'])

# Widgets are created in this order on purpose: scalar inputs first, then the
# discrete multiselects, which is the order they appear on the page.
dynamic_programmatic_search_config = {
    "scalar_columns": [],
    "discrete_columns": [],
}
dynamic_programmatic_search_config['scalar_columns'] = _collect_scalar_filters(
    programmatic_search_config['scalar_columns']
)
dynamic_programmatic_search_config['discrete_columns'] = _collect_discrete_filters(
    programmatic_search_config['discrete_columns'], search_utils, df
)

programmatic_search_config['scalar_columns'] = dynamic_programmatic_search_config['scalar_columns']
programmatic_search_config['discrete_columns'] = dynamic_programmatic_search_config['discrete_columns']

# Search button
if st.button('Search'):
    # A blank or whitespace-only query is treated as "no query entered".
    if not query.strip():
        st.write("Please enter a query to search.")
    else:
        df_retrieved = search_utils.retrieve(
            query, df, model, index, top_k=RETRIEVE_TOP_K, api=True
        )
        df_filtered = search_utils.filter_dataframe(df_retrieved, programmatic_search_config)
        # NOTE(review): ascending order keeps the rows with the SMALLEST
        # 'similarities' values. That is right if the column holds distances;
        # if it holds similarity scores this should be descending - confirm
        # against UtilsSearch.retrieve.
        df_filtered = df_filtered.sort_values(by='similarities', ascending=True)
        # Reset the index so reranker positions line up with row positions.
        df_filtered = df_filtered[:RERANK_CANDIDATES].reset_index(drop=True)

        if df_filtered.empty:
            st.write('No results found')
        else:
            if use_cohere:
                top_positions = _rerank_with_cohere(query, df_filtered)
            else:
                top_positions = _rerank_with_mixedbread(query, df_filtered)

            # The rerankers report positions into the submitted document list,
            # so select positionally; copy before adding the rank column so the
            # candidate frame is left untouched.
            results_df = df_filtered.iloc[top_positions].copy()
            results_df['rank'] = np.arange(len(results_df)) + 1
            st.write(results_df)