"""Streamlit search demo: filter a dataframe, then retrieve and rerank rows.

Ranking uses either the local SentenceTransformer + CrossEncoder pipeline
(``search_utils.search``) or retrieval with the sentence-transformer model
followed by Cohere's hosted reranker, depending on the "Use Cohere" checkbox.
"""
from copy import deepcopy

import cohere
import numpy as np
import streamlit as st
import torch
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer

import src.constants.config as config_module
from src.constants.credentials import cohere_trial_key
from src.reader import Reader
from src.utils_search import UtilsSearch

configurations = config_module.service_mxbai_msc_direct_config
# NOTE(review): the API key comes from a source-controlled credentials module;
# consider st.secrets or an environment variable instead.
api_key = cohere_trial_key
co = cohere.Client(api_key)
semantic_column_names = configurations["semantic_column_names"]

# Check CUDA availability and pick the device once, so model loading in init()
# honours the CPU fallback instead of always requesting 'cuda:0'.
if torch.cuda.is_available():
    torch.cuda.set_device(0)  # Use the first GPU
    DEVICE = "cuda:0"
else:
    st.write("CUDA is not available. Using CPU instead.")
    DEVICE = "cpu"


@st.cache_resource
def init():
    """Load the data, models and search index once per server process.

    ``st.cache_resource`` is used instead of ``st.cache_data`` because the
    result holds models and an index; those are shared as-is rather than
    serialized and copied on every cache hit. Because the objects are shared
    across sessions, callers must not mutate them in place.

    Returns:
        tuple: ``(df, model, cross_encoder, index, search_utils)``.
    """
    config = configurations
    search_utils = UtilsSearch(config)
    reader = Reader(config=config["reader_config"])
    model = SentenceTransformer(config["sentence_transformer_name"], device=DEVICE)
    cross_encoder = CrossEncoder(config["cross_encoder_name"], device=DEVICE)
    df = reader.read()
    index = search_utils.dataframe_to_index(df)
    return df, model, cross_encoder, index, search_utils


def get_possible_values_for_column(column_name, search_utils, df):
    """Return the candidate values for a discrete column, memoised per session.

    Values come from ``search_utils.top_10_common_values(df, column_name)``.
    They are stored in ``st.session_state`` under a namespaced key, so a column
    name cannot clash with other entries such as ``init_results``.
    """
    state_key = f"possible_values::{column_name}"
    if state_key not in st.session_state:
        st.session_state[state_key] = search_utils.top_10_common_values(df, column_name)
    return st.session_state[state_key]


def _cohere_rerank(query, df_retrieved, top_n=10):
    """Rerank retrieved rows with Cohere; return the top rows with a 1-based ``rank``."""
    if df_retrieved.empty:
        # Nothing to rerank: avoid calling the API with no documents.
        results_df = df_retrieved.copy()
    else:
        # Only the semantic columns are sent, stringified with missing values as "".
        # The fill is applied to this projection only, so the rows shown to the
        # user keep their original values and dtypes.
        docs = (
            df_retrieved[semantic_column_names]
            .fillna("")
            .astype(str)
            .to_dict("records")
        )
        rank_fields = list(semantic_column_names)
        response = co.rerank(
            query=query,
            documents=docs,
            top_n=top_n,
            model="rerank-english-v3.0",
            rank_fields=rank_fields,
        )
        # Each hit's .index is the position of the document in `docs`, which
        # matches the row position in df_retrieved.
        top_ids = [hit.index for hit in response.results]
        results_df = df_retrieved.iloc[top_ids].copy()
    results_df["rank"] = np.arange(len(results_df)) + 1
    return results_df


# Initialize or retrieve from session state
if "init_results" not in st.session_state:
    st.session_state.init_results = init()

df, model, cross_encoder, index, search_utils = st.session_state.init_results

# Streamlit app layout
st.title("Search Demo")

# Input fields
query = st.text_input("Enter your search query here")
use_cohere = st.checkbox("Use Cohere", value=False)  # Unchecked by default

# Build the filter config from user input, starting from a copy of the defaults.
programmatic_search_config = deepcopy(configurations["programmatic_search_config"])
dynamic_programmatic_search_config = {
    "scalar_columns": [],
    "discrete_columns": [],
}

for column in programmatic_search_config["scalar_columns"]:
    # Number inputs bounded by the configured range of this scalar column.
    col_name = column["column_name"]
    min_val = float(column["min_value"])
    max_val = float(column["max_value"])
    user_min = st.number_input(
        f"Minimum {col_name.capitalize()}",
        min_value=min_val,
        max_value=max_val,
        value=min_val,
    )
    user_max = st.number_input(
        f"Maximum {col_name.capitalize()}",
        min_value=min_val,
        max_value=max_val,
        value=max_val,
    )
    dynamic_programmatic_search_config["scalar_columns"].append(
        {"column_name": col_name, "min_value": user_min, "max_value": user_max}
    )

for column in programmatic_search_config["discrete_columns"]:
    # Multiselect over the column's most common values.
    col_name = column["column_name"]
    default_values = column["default_values"]
    possible_values = list(get_possible_values_for_column(col_name, search_utils, df))
    # st.multiselect raises if a default is not among its options. The configured
    # defaults need not be in the top-10 common values, so append any missing ones.
    options = possible_values + [
        value for value in (default_values or []) if value not in possible_values
    ]
    selected_values = st.multiselect(
        f"Select {col_name.capitalize()}",
        options=options,
        default=default_values,
    )
    dynamic_programmatic_search_config["discrete_columns"].append(
        {"column_name": col_name, "default_values": selected_values}
    )

programmatic_search_config["scalar_columns"] = dynamic_programmatic_search_config["scalar_columns"]
programmatic_search_config["discrete_columns"] = dynamic_programmatic_search_config["discrete_columns"]

# Search button
if st.button("Search"):
    if not query:
        st.write("Please enter a query to search.")
    else:
        df_filtered = search_utils.filter_dataframe(df, programmatic_search_config)
        if len(df_filtered) == 0:
            st.write("No results found")
        else:
            # Rebuild the index over the filtered rows only. A separate name keeps
            # the cached full-dataset `index` untouched.
            filtered_index = search_utils.dataframe_to_index(df_filtered)
            if not use_cohere:
                # Local pipeline: search with the sentence-transformer model and
                # the cross-encoder from init().
                results_df = search_utils.search(
                    query, df_filtered, model, cross_encoder, filtered_index
                )
                results_df = search_utils.drop_columns(results_df, programmatic_search_config)
            else:
                # Retrieve candidates locally, then rerank them with Cohere.
                df_retrieved = search_utils.retrieve(query, df_filtered, model, filtered_index)
                df_retrieved = search_utils.drop_columns(df_retrieved, programmatic_search_config)
                results_df = _cohere_rerank(query, df_retrieved)
            st.write(results_df)