# Spaces: Running (page-header text captured with the source; not part of the program)
import torch | |
import src.constants.config as configurations | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers import CrossEncoder | |
from src.constants.credentials import cohere_trial_key, mixedbread_key | |
import streamlit as st | |
from src.reader import Reader | |
from src.utils_search import UtilsSearch | |
from copy import deepcopy | |
import numpy as np | |
import cohere | |
from mixedbread_ai.client import MixedbreadAI | |
from src.pytorch_modules.datasets.schema_string_dataset import SchemaStringDataset | |
# Rebind the imported config module name to the one service configuration
# this app uses; everything below indexes into this object.
configurations = configurations.service_mxbai_msc_direct_config
semantic_column_names = configurations["semantic_column_names"]
cross_encoder_name = configurations["cross_encoder_name"]

# API clients: Cohere (used only by the Cohere rerank path) and Mixedbread
# (passed to retrieval and used for reranking).
api_key = cohere_trial_key
co = cohere.Client(api_key)
model = MixedbreadAI(api_key=mixedbread_key)
def init():
    """Build the objects the app needs before it can serve a query.

    Uses the module-level ``configurations`` object.

    Returns:
        A ``(df, index, search_utils)`` tuple: the result of
        ``Reader.read()``, the index produced from it by
        ``UtilsSearch.dataframe_to_index``, and the ``UtilsSearch``
        instance itself.
    """
    search_utils = UtilsSearch(configurations)
    df = Reader(config=configurations["reader_config"]).read()
    return df, search_utils.dataframe_to_index(df), search_utils
def get_possible_values_for_column(column_name, search_utils, df):
    """Return the selectable values for ``column_name``, cached in session state.

    The values come from ``search_utils.top_10_common_values(df, column_name)``
    and are stored in ``st.session_state`` under the column name itself, so the
    helper is only called when that key is not already present.

    Args:
        column_name: Column to look up; also used as the session-state key.
        search_utils: Object providing ``top_10_common_values``.
        df: Data passed through to ``top_10_common_values``.

    Returns:
        Whatever is stored in session state under ``column_name``.
    """
    cache = st.session_state
    if column_name not in cache:
        cache[column_name] = search_utils.top_10_common_values(df, column_name)
    return cache[column_name]
# Run init() only when session state does not already hold its results,
# then unpack the stored tuple for use by the rest of the script.
if 'init_results' not in st.session_state:
    st.session_state['init_results'] = init()
df, index, search_utils = st.session_state['init_results']

# Page header and free-text query box.
st.title('Search Demo')
query = st.text_input('Enter your search query here')

# Cohere reranking is switched off. To expose it as a user toggle, replace the
# assignment below with: st.checkbox('Use Cohere', value=False)
use_cohere = False
# Work on a private copy so the user's selections never write back into the
# shared configuration object.
programmatic_search_config = deepcopy(configurations['programmatic_search_config'])
dynamic_programmatic_search_config = {
    "scalar_columns": [],
    "discrete_columns": [],
}

# Scalar columns: one "minimum" and one "maximum" number input per column,
# both bounded by the configured range and defaulting to its two ends.
for column in programmatic_search_config['scalar_columns']:
    col_name = column["column_name"]
    min_val = float(column["min_value"])
    max_val = float(column["max_value"])
    user_min = st.number_input(
        f'Minimum {col_name.capitalize()}',
        min_value=min_val, max_value=max_val, value=min_val,
    )
    user_max = st.number_input(
        f'Maximum {col_name.capitalize()}',
        min_value=min_val, max_value=max_val, value=max_val,
    )
    dynamic_programmatic_search_config['scalar_columns'].append(
        {"column_name": col_name, "min_value": user_min, "max_value": user_max}
    )

# Discrete columns: one multiselect per column, pre-populated with the
# configured defaults.
for column in programmatic_search_config['discrete_columns']:
    col_name = column["column_name"]
    default_values = column["default_values"]
    possible_values = get_possible_values_for_column(col_name, search_utils, df)

    # st.multiselect raises when a default is not among the options. The
    # options are only the values returned by top_10_common_values, so a
    # configured default may be absent; append any missing defaults instead of
    # letting the page crash. When nothing is missing the options are passed
    # through untouched.
    known_values = list(possible_values)
    missing_defaults = [
        value for value in (default_values or []) if value not in known_values
    ]
    options = known_values + missing_defaults if missing_defaults else possible_values

    selected_values = st.multiselect(
        f'Select {col_name.capitalize()}', options=options, default=default_values
    )
    dynamic_programmatic_search_config['discrete_columns'].append(
        {"column_name": col_name, "default_values": selected_values}
    )

# Replace the configured filters with the values chosen in the UI.
programmatic_search_config['scalar_columns'] = dynamic_programmatic_search_config['scalar_columns']
programmatic_search_config['discrete_columns'] = dynamic_programmatic_search_config['discrete_columns']
# Search button: retrieve candidates, apply the UI filters, rerank the
# remaining candidates, and display the top results with a 1-based rank.
if st.button('Search'):
    if not query:
        st.write("Please enter a query to search.")
    else:
        df_retrieved = search_utils.retrieve(query, df, model, index, top_k=1000, api=True)
        df_filtered = search_utils.filter_dataframe(df_retrieved, programmatic_search_config)
        # NOTE(review): ascending order keeps the 100 rows with the SMALLEST
        # 'similarities' values. That is right if the column holds distances;
        # if it holds similarity scores this keeps the worst matches - confirm
        # against UtilsSearch.retrieve.
        df_filtered = df_filtered.sort_values(by='similarities', ascending=True)
        # Reset to a 0..n-1 index so row positions line up with the positions
        # of the documents handed to the reranker.
        df_filtered = df_filtered[:100].reset_index(drop=True)
        if len(df_filtered) == 0:
            st.write('No results found')
        else:
            if not use_cohere:
                records = df_filtered.to_dict(orient='records')
                dataset_str = SchemaStringDataset(records, configurations)
                # Only the first 256 characters of each document string are
                # sent to the reranker.
                documents = [batch["inputs"][:256] for batch in dataset_str]
                res = model.reranking(
                    model=cross_encoder_name,
                    query=query,
                    input=documents,
                    top_k=10,
                    return_input=False,
                )
                ids = [item.index for item in res.data]
                # The indices refer to positions in `documents`, so select by
                # position; copy so the 'rank' column is added to an
                # independent frame, matching the Cohere branch.
                results_df = df_filtered.iloc[ids].copy()
            else:
                df_filtered = df_filtered.fillna(value="")
                docs = df_filtered.to_dict('records')
                column_names = semantic_column_names
                # Keep only the semantic columns, stringified, for reranking.
                docs = [{name: str(doc[name]) for name in column_names} for doc in docs]
                rank_fields = list(docs[0].keys())
                results = co.rerank(query=query, documents=docs, top_n=10, model='rerank-english-v3.0',
                                    rank_fields=rank_fields)
                top_ids = [hit.index for hit in results.results]
                # Create the DataFrame with the rerank results
                results_df = df_filtered.iloc[top_ids].copy()
            results_df['rank'] = (np.arange(len(results_df)) + 1)
            st.write(results_df)