File size: 5,423 Bytes
37c2a8d
 
 
 
b1179cf
37c2a8d
 
 
 
 
 
b1179cf
 
37c2a8d
 
 
 
 
 
b1179cf
 
37c2a8d
 
 
 
 
 
 
 
b1179cf
37c2a8d
 
 
 
 
 
 
 
 
 
 
 
b1179cf
37c2a8d
 
 
 
 
 
b1179cf
 
37c2a8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1179cf
 
 
 
 
37c2a8d
 
 
 
b1179cf
 
 
 
 
 
 
 
 
 
 
 
37c2a8d
 
b1179cf
 
37c2a8d
 
 
 
 
 
 
b1179cf
37c2a8d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import src.constants.config as configurations
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
from src.constants.credentials import cohere_trial_key, mixedbread_key
import streamlit as st
from src.reader import Reader
from src.utils_search import UtilsSearch
from copy import deepcopy
import numpy as np
import cohere
from mixedbread_ai.client import MixedbreadAI
from src.pytorch_modules.datasets.schema_string_dataset import SchemaStringDataset


# Select the active service configuration.
# NOTE: this rebinds the name of the imported config module to the chosen
# configuration object, so the module is no longer reachable by this name.
configurations = configurations.service_mxbai_msc_direct_config

# Settings looked up from the active configuration.
semantic_column_names = configurations["semantic_column_names"]
cross_encoder_name = configurations["cross_encoder_name"]

# API clients: `co` talks to Cohere, `model` to Mixedbread.
api_key = cohere_trial_key
co = cohere.Client(api_key)
model = MixedbreadAI(api_key=mixedbread_key)

@st.cache_data
def init():
    """Load the data and build everything the search page needs.

    Creates a ``UtilsSearch`` from the module-level configuration, reads the
    dataframe through a ``Reader`` configured with ``reader_config``, and
    builds an index from that dataframe.

    Returns:
        A ``(df, index, search_utils)`` tuple.

    NOTE(review): ``st.cache_data`` is meant for serializable return values;
    if the index or the ``UtilsSearch`` instance cannot be pickled,
    ``st.cache_resource`` may be the appropriate decorator -- confirm.
    """
    utils = UtilsSearch(configurations)
    frame = Reader(config=configurations["reader_config"]).read()
    return frame, utils.dataframe_to_index(frame), utils

def get_possible_values_for_column(column_name, search_utils, df):
    """Return the selectable values for ``column_name``, stored per session.

    On the first request for a column the values are computed with
    ``search_utils.top_10_common_values(df, column_name)`` and stored in
    ``st.session_state`` under the column's own name; later requests return
    the stored values without recomputing them.
    """
    state = st.session_state
    if column_name not in state:
        top_values = search_utils.top_10_common_values(df, column_name)
        setattr(state, column_name, top_values)
    return getattr(state, column_name)


# Run init() only when the session state does not already hold its result,
# so that script reruns reuse the stored dataframe, index and search helper.
if 'init_results' not in st.session_state:
    st.session_state['init_results'] = init()

# Unpack the stored objects for use by the rest of the script.
df, index, search_utils = st.session_state['init_results']

# Page header followed by the free-text query box.
st.title('Search Demo')
query = st.text_input('Enter your search query here')

# Cohere reranking is switched off. To expose it as a user option again,
# replace the constant with: st.checkbox('Use Cohere', value=False)
use_cohere = False

# Work on a private copy so that replacing its column lists further down
# does not modify the shared configuration object.
programmatic_search_config = deepcopy(configurations['programmatic_search_config'])

# Collects the filter settings chosen through the widgets.
dynamic_programmatic_search_config = dict(scalar_columns=[], discrete_columns=[])


# Scalar filters: one minimum and one maximum number input per configured
# column, both bounded by the configured range and defaulting to its ends.
for column in programmatic_search_config['scalar_columns']:
    col_name = column["column_name"]
    min_val = float(column["min_value"])
    max_val = float(column["max_value"])
    user_min = st.number_input(f'Minimum {col_name.capitalize()}', min_value=min_val, max_value=max_val, value=min_val)
    user_max = st.number_input(f'Maximum {col_name.capitalize()}', min_value=min_val, max_value=max_val, value=max_val)
    dynamic_programmatic_search_config['scalar_columns'].append(
        {"column_name": col_name, "min_value": user_min, "max_value": user_max}
    )

# Discrete filters: one multiselect per configured column.
for column in programmatic_search_config['discrete_columns']:
    col_name = column["column_name"]
    default_values = column["default_values"]
    possible_values = get_possible_values_for_column(col_name, search_utils, df)

    # st.multiselect rejects defaults that are not among the options. The
    # options come from the most common values only, so a configured default
    # can be missing from them; append any such default. A new list is built
    # so the values stored in the session state are never mutated.
    options = list(possible_values)
    for value in (default_values or []):
        if value not in options:
            options.append(value)

    selected_values = st.multiselect(f'Select {col_name.capitalize()}', options=options, default=default_values)
    dynamic_programmatic_search_config['discrete_columns'].append(
        {"column_name": col_name, "default_values": selected_values}
    )


# Replace the configured filter settings with the ones chosen in the widgets.
programmatic_search_config['scalar_columns'] = dynamic_programmatic_search_config['scalar_columns']
programmatic_search_config['discrete_columns'] = dynamic_programmatic_search_config['discrete_columns']


# Search button: retrieve candidates, apply the user's filters, rerank the
# best candidates and display the top results.
if st.button('Search'):
    if query:  # Checking if a query was entered
        df_retrieved = search_utils.retrieve(query, df, model, index, top_k=1000, api=True)
        df_filtered = search_utils.filter_dataframe(df_retrieved, programmatic_search_config)
        # NOTE(review): an ascending sort keeps the 100 rows with the smallest
        # 'similarities'; that is only right if the column holds a distance
        # (smaller = closer) -- confirm against UtilsSearch.retrieve.
        df_filtered = df_filtered.sort_values(by='similarities', ascending=True)
        # Fresh 0..n-1 index so that the positional indices returned by the
        # rerankers line up with the rows.
        df_filtered = df_filtered[:100].reset_index(drop=True)

        if df_filtered.empty:
            st.write('No results found')
        else:
            if not use_cohere:
                # Mixedbread reranking over the string form of each record,
                # truncated to 256 characters per document.
                records = df_filtered.to_dict(orient='records')
                dataset_str = SchemaStringDataset(records, configurations)
                documents = [batch["inputs"][:256] for batch in dataset_str]
                res = model.reranking(
                    model=cross_encoder_name,
                    query=query,
                    input=documents,
                    top_k=10,
                    return_input=False
                )
                top_ids = [item.index for item in res.data]
            else:
                # Cohere reranking over the semantic columns only, with
                # missing values replaced by empty strings.
                df_filtered = df_filtered.fillna(value="")
                docs = df_filtered.to_dict('records')
                docs = [{name: str(doc[name]) for name in semantic_column_names} for doc in docs]
                rank_fields = list(docs[0].keys())
                results = co.rerank(query=query, documents=docs, top_n=10, model='rerank-english-v3.0',
                                    rank_fields=rank_fields)
                top_ids = [hit.index for hit in results.results]

            # Reranker indices are positions in the submitted document list,
            # so select rows positionally. Both branches get the same 'rank'
            # column, numbered in the order the reranker returned the hits.
            results_df = df_filtered.iloc[top_ids].copy()
            results_df['rank'] = np.arange(1, len(results_df) + 1)

            st.write(results_df)
    else:
        st.write("Please enter a query to search.")