KoonJamesZ committed on
Commit
4b2df1b
·
verified ·
1 Parent(s): bae5dd9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -1,5 +1,8 @@
1
  import gradio as gr
2
  import pandas as pd
 
 
 
3
  from FlagEmbedding import BGEM3FlagModel
4
 
5
  # Load the pre-trained embedding model
@@ -11,31 +14,26 @@ df['embeding_context'] = df['embeding_context'].astype(str).fillna('')
11
 
12
  # Filter out any rows where 'embeding_context' might be empty or invalid
13
  df = df[df['embeding_context'] != '']
 
 
 
14
 
15
- # Encode the 'embeding_context' column
16
- embedding_contexts = df['embeding_context'].tolist()
17
- embeddings_csv = model.encode(embedding_contexts, batch_size=12, max_length=2048)['dense_vecs']
18
  # Function to perform search and return all columns
19
  def search_query(query_text):
20
  num_records = 50
21
 
22
  # Encode the input query text
23
  embeddings_query = model.encode([query_text], batch_size=12, max_length=2048)['dense_vecs']
24
-
25
- # Compute similarity between the query and the CSV embeddings
26
- similarity_matrix = embeddings_query @ embeddings_csv.T
27
 
28
- # Rank records by similarity and select the top 'num_records'
29
- similarity_scores = similarity_matrix.max(axis=0)
30
- top_indices = similarity_scores.argsort()[-num_records:][::-1]
31
 
32
- # Get the top results and return all columns
33
- result_df = df.iloc[top_indices].drop_duplicates(subset=df.columns.difference(['embedding_context']), keep='first')
34
 
35
-
36
  return result_df
37
 
38
-
39
  # Gradio interface function
40
  def gradio_interface(query_text):
41
  search_results = search_query(query_text)
 
1
  import gradio as gr
2
  import pandas as pd
3
+ import faiss
4
+ import numpy as np
5
+ import os
6
  from FlagEmbedding import BGEM3FlagModel
7
 
8
  # Load the pre-trained embedding model
 
14
 
15
  # Filter out any rows where 'embeding_context' might be empty or invalid
16
  df = df[df['embeding_context'] != '']
17
+
18
+ index = faiss.read_index('vector_store.index')
19
+
20
 
 
 
 
21
  # Function to perform search and return all columns
22
  def search_query(query_text):
23
  num_records = 50
24
 
25
  # Encode the input query text
26
  embeddings_query = model.encode([query_text], batch_size=12, max_length=2048)['dense_vecs']
27
+ embeddings_query_np = np.array(embeddings_query).astype('float32')
 
 
28
 
29
+ # Search in FAISS index for nearest neighbors
30
+ distances, indices = index.search(embeddings_query_np, num_records)
 
31
 
32
+ # Get the top results based on FAISS indices
33
+ result_df = df.iloc[indices[0]].drop(columns=['embeding_context']).drop_duplicates().reset_index(drop=True)
34
 
 
35
  return result_df
36
 
 
37
  # Gradio interface function
38
  def gradio_interface(query_text):
39
  search_results = search_query(query_text)