Yoxas committed on
Commit
d1b7d86
·
verified ·
1 Parent(s): 1f1368a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -40,7 +40,7 @@ else:
40
  gpu_index = faiss.IndexFlatL2(dimension) # fall back to CPU
41
 
42
  # Ensure embeddings are stacked as float32
43
- embeddings = np.vstack(data['embedding'].values).astype(np.float32)
44
  logging.debug(f"Embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
45
  gpu_index.add(embeddings)
46
 
@@ -61,7 +61,7 @@ def embed_question(question, model, tokenizer):
61
  logging.debug(f"Tokenized inputs: {inputs}")
62
  with torch.no_grad():
63
  outputs = model(**inputs)
64
- embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().astype(np.float32)
65
  logging.debug(f"Question embedding shape: {embedding.shape}")
66
  logging.debug(f"Question embedding content: {embedding}")
67
  return embedding
@@ -78,7 +78,7 @@ def retrieve_and_generate(question):
78
  question_embedding = embed_question(question, model, tokenizer)
79
 
80
  # Ensure the embedding is in the correct format for FAISS search
81
- question_embedding = question_embedding.astype(np.float32)
82
 
83
  # Search in FAISS index
84
  try:
 
40
  gpu_index = faiss.IndexFlatL2(dimension) # fall back to CPU
41
 
42
  # Ensure embeddings are stacked as float32
43
+ embeddings = np.vstack(data['embedding'].values).astype(np.float16)
44
  logging.debug(f"Embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
45
  gpu_index.add(embeddings)
46
 
 
61
  logging.debug(f"Tokenized inputs: {inputs}")
62
  with torch.no_grad():
63
  outputs = model(**inputs)
64
+ embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().astype(np.float16)
65
  logging.debug(f"Question embedding shape: {embedding.shape}")
66
  logging.debug(f"Question embedding content: {embedding}")
67
  return embedding
 
78
  question_embedding = embed_question(question, model, tokenizer)
79
 
80
  # Ensure the embedding is in the correct format for FAISS search
81
+ question_embedding = question_embedding.astype(np.float16)
82
 
83
  # Search in FAISS index
84
  try: