Yoxas committed on
Commit
9cc49fe
·
verified ·
1 Parent(s): d1b7d86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -17,10 +17,10 @@ data = pd.read_csv('RBDx10kstats.csv')
17
  # Function to safely convert JSON strings to numpy arrays
18
  def safe_json_loads(x):
19
  try:
20
- return np.array(json.loads(x), dtype=np.float16) # Ensure the array is of type float32
21
  except json.JSONDecodeError as e:
22
  logging.error(f"Error decoding JSON: {e}")
23
- return np.array([], dtype=np.float16) # Return an empty array or handle it as appropriate
24
 
25
  # Apply the safe_json_loads function to the embedding column
26
  data['embedding'] = data['embedding'].apply(safe_json_loads)
@@ -40,7 +40,7 @@ else:
40
  gpu_index = faiss.IndexFlatL2(dimension) # fall back to CPU
41
 
42
  # Ensure embeddings are stacked as float32
43
- embeddings = np.vstack(data['embedding'].values).astype(np.float16)
44
  logging.debug(f"Embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
45
  gpu_index.add(embeddings)
46
 
@@ -61,7 +61,7 @@ def embed_question(question, model, tokenizer):
61
  logging.debug(f"Tokenized inputs: {inputs}")
62
  with torch.no_grad():
63
  outputs = model(**inputs)
64
- embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().astype(np.float16)
65
  logging.debug(f"Question embedding shape: {embedding.shape}")
66
  logging.debug(f"Question embedding content: {embedding}")
67
  return embedding
@@ -78,16 +78,16 @@ def retrieve_and_generate(question):
78
  question_embedding = embed_question(question, model, tokenizer)
79
 
80
  # Ensure the embedding is in the correct format for FAISS search
81
- question_embedding = question_embedding.astype(np.float16)
82
 
83
  # Search in FAISS index
84
  try:
85
  logging.debug(f"Searching FAISS index with question embedding: {question_embedding}")
86
- _, indices = gpu_index.search(question_embedding, k=1)
87
- if indices.size == 0:
88
  logging.error("No results found in FAISS search.")
89
  return "No relevant document found."
90
- logging.debug(f"Indices found: {indices}")
91
  except Exception as e:
92
  logging.error(f"Error during FAISS search: {e}")
93
  return f"An error occurred during search: {e}"
 
17
  # Function to safely convert JSON strings to numpy arrays
18
  def safe_json_loads(x):
19
  try:
20
+ return np.array(json.loads(x), dtype=np.float32) # Ensure the array is of type float32
21
  except json.JSONDecodeError as e:
22
  logging.error(f"Error decoding JSON: {e}")
23
+ return np.array([], dtype=np.float32) # Return an empty array or handle it as appropriate
24
 
25
  # Apply the safe_json_loads function to the embedding column
26
  data['embedding'] = data['embedding'].apply(safe_json_loads)
 
40
  gpu_index = faiss.IndexFlatL2(dimension) # fall back to CPU
41
 
42
  # Ensure embeddings are stacked as float32
43
+ embeddings = np.vstack(data['embedding'].values).astype(np.float32)
44
  logging.debug(f"Embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
45
  gpu_index.add(embeddings)
46
 
 
61
  logging.debug(f"Tokenized inputs: {inputs}")
62
  with torch.no_grad():
63
  outputs = model(**inputs)
64
+ embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().astype(np.float32)
65
  logging.debug(f"Question embedding shape: {embedding.shape}")
66
  logging.debug(f"Question embedding content: {embedding}")
67
  return embedding
 
78
  question_embedding = embed_question(question, model, tokenizer)
79
 
80
  # Ensure the embedding is in the correct format for FAISS search
81
+ question_embedding = question_embedding.astype(np.float32)
82
 
83
  # Search in FAISS index
84
  try:
85
  logging.debug(f"Searching FAISS index with question embedding: {question_embedding}")
86
+ distances, indices = gpu_index.search(question_embedding, k=1)
87
+ if len(indices) == 0:
88
  logging.error("No results found in FAISS search.")
89
  return "No relevant document found."
90
+ logging.debug(f"Indices found: {indices}, Distances: {distances}")
91
  except Exception as e:
92
  logging.error(f"Error during FAISS search: {e}")
93
  return f"An error occurred during search: {e}"