import gradio as gr
import pandas as pd
import numpy as np
from transformers import pipeline, BertTokenizer, BertModel
import faiss
import torch
import json
import spaces  # Hugging Face Spaces helper package (not used directly below)
import logging
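
# NOTE (assumption): this Space is expected to list gradio, pandas, numpy,
# transformers, torch, and faiss-gpu (or faiss-cpu on CPU-only hardware)
# in its requirements.txt; that file is not shown here.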

# Set up logging
logging.basicConfig(level=logging.DEBUG)

# Load CSV data
data = pd.read_csv('RBDx10kstats.csv')
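
# Expected CSV schema (inferred from the code below): an 'embedding' column of
# JSON-encoded float vectors, all of the same length, and an 'Abstract' column
# holding the document text used as the QA context.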

# Function to safely convert JSON strings to numpy arrays
def safe_json_loads(x):
    try:
        # FAISS requires float32 vectors
        return np.array(json.loads(x), dtype=np.float32)
    except json.JSONDecodeError as e:
        logging.error(f"Error decoding JSON: {e}")
        return np.array([], dtype=np.float32)  # Return an empty array so the row can be filtered out

# Apply the safe_json_loads function to the embedding column
data['embedding'] = data['embedding'].apply(safe_json_loads)

# Filter out any rows with empty embeddings
data = data[data['embedding'].apply(lambda x: x.size > 0)]

# Initialize FAISS index
dimension = len(data['embedding'].iloc[0])

# Create FAISS index (on GPU when one is available, otherwise on CPU).
# StandardGpuResources is only created inside the GPU branch, since it does
# not exist in CPU-only FAISS builds.
if faiss.get_num_gpus() > 0:
    res = faiss.StandardGpuResources()  # use a single GPU
    gpu_index = faiss.index_cpu_to_gpu(res, 0, faiss.IndexFlatL2(dimension))  # move to GPU
else:
    gpu_index = faiss.IndexFlatL2(dimension)  # fall back to CPU

# Ensure embeddings are stacked as float32 (FAISS only accepts float32 input)
embeddings = np.vstack(data['embedding'].values).astype(np.float32)
logging.debug(f"Embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
gpu_index.add(embeddings)
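# Sanity check: the index should now hold one vector per row with a valid embedding
logging.debug(f"FAISS index contains {gpu_index.ntotal} vectors")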
# Check if GPU is available | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
# Load QA model | |
qa_model = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad", device=0 if torch.cuda.is_available() else -1) | |
# Load BERT model and tokenizer | |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') | |
model = BertModel.from_pretrained('bert-base-uncased').to(device) | |

# Function to embed the question using BERT (mean-pooled last hidden state)
def embed_question(question, model, tokenizer):
    try:
        inputs = tokenizer(question, return_tensors='pt').to(device)
        logging.debug(f"Tokenized inputs: {inputs}")
        with torch.no_grad():
            outputs = model(**inputs)
        # Mean-pool over the token dimension and cast to float32 for FAISS
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().astype(np.float32)
        logging.debug(f"Question embedding shape: {embedding.shape}")
        logging.debug(f"Question embedding content: {embedding}")
        return embedding
    except Exception as e:
        logging.error(f"Error embedding question: {e}")
        raise

# Function to retrieve the relevant document and generate a response
def retrieve_and_generate(question):
    logging.debug(f"Received question: {question}")
    try:
        # Embed the question
        question_embedding = embed_question(question, model, tokenizer)

        # Ensure the embedding is in the correct format for FAISS search
        question_embedding = question_embedding.astype(np.float32)

        # Search in FAISS index
        try:
            logging.debug(f"Searching FAISS index with question embedding: {question_embedding}")
            _, indices = gpu_index.search(question_embedding, k=1)
            if indices.size == 0:
                logging.error("No results found in FAISS search.")
                return "No relevant document found."
            logging.debug(f"Indices found: {indices}")
        except Exception as e:
            logging.error(f"Error during FAISS search: {e}")
            return f"An error occurred during search: {e}"

        # Retrieve the most relevant document
        try:
            relevant_doc = data.iloc[indices[0][0]]
            logging.debug(f"Relevant document: {relevant_doc}")
        except Exception as e:
            logging.error(f"Error retrieving document: {e}")
            return "An error occurred while retrieving the document. Please try again."

        # Use the QA model to generate the answer
        try:
            context = relevant_doc['Abstract']
            response = qa_model(question=question, context=context)
            logging.debug(f"Response: {response}")
            return response['answer']
        except Exception as e:
            logging.error(f"Error generating answer: {e}")
            return "An error occurred while generating the answer. Please try again."
    except Exception as e:
        logging.error(f"Error during retrieval and generation: {e}")
        return "An error occurred. Please try again."

# Create a Gradio interface
interface = gr.Interface(
    fn=retrieve_and_generate,
    inputs=gr.Textbox(lines=2, placeholder="Ask a question about the documents..."),
    outputs="text",
    title="RAG Chatbot",
    description="Ask questions about the documents in the CSV file."
)

# Launch the Gradio app
interface.launch()