import gradio as gr
import pandas as pd
import numpy as np
from transformers import pipeline, BertTokenizer, BertModel
import faiss
import torch
import json
import spaces
import logging

# Set up logging
logging.basicConfig(level=logging.DEBUG)

# Load CSV data
data = pd.read_csv('RBDx10kstats.csv')

# Function to safely convert JSON strings to numpy arrays
def safe_json_loads(x):
    try:
        return np.array(json.loads(x), dtype=np.float32)  # float32 is the dtype FAISS expects
    except json.JSONDecodeError as e:
        logging.error(f"Error decoding JSON: {e}")
        return np.array([], dtype=np.float32)  # Empty array marks the row for filtering below

# Apply the safe_json_loads function to the embedding column
data['embedding'] = data['embedding'].apply(safe_json_loads)

# Filter out any rows with empty embeddings
data = data[data['embedding'].apply(lambda x: x.size > 0)]

# Initialize FAISS index; only touch GPU resources when a GPU is actually present,
# since faiss.StandardGpuResources() is unavailable on CPU-only builds
dimension = len(data['embedding'].iloc[0])
if faiss.get_num_gpus() > 0:
    res = faiss.StandardGpuResources()  # use a single GPU
    gpu_index = faiss.index_cpu_to_gpu(res, 0, faiss.IndexFlatL2(dimension))  # move to GPU
else:
    gpu_index = faiss.IndexFlatL2(dimension)  # fall back to CPU

# Ensure embeddings are stacked as float32
embeddings = np.vstack(data['embedding'].values).astype(np.float32)
logging.debug(f"Embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
gpu_index.add(embeddings)

# Check if a GPU is available for PyTorch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load QA model
qa_model = pipeline(
    "question-answering",
    model="distilbert-base-uncased-distilled-squad",
    device=0 if torch.cuda.is_available() else -1,
)

# Load BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased').to(device)

# Function to embed the question using BERT (mean-pooled last hidden state)
def embed_question(question, model, tokenizer):
    try:
        inputs = tokenizer(question, return_tensors='pt').to(device)
        logging.debug(f"Tokenized inputs: {inputs}")
        with torch.no_grad():
            outputs = model(**inputs)
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().astype(np.float32)
        logging.debug(f"Question embedding shape: {embedding.shape}")
        logging.debug(f"Question embedding content: {embedding}")
        return embedding
    except Exception as e:
        logging.error(f"Error embedding question: {e}")
        raise

# Function to retrieve the relevant document and generate a response
@spaces.GPU(duration=120)
def retrieve_and_generate(question):
    logging.debug(f"Received question: {question}")
    try:
        # Embed the question
        question_embedding = embed_question(question, model, tokenizer)

        # Ensure the embedding is in the correct format for FAISS search
        question_embedding = question_embedding.astype(np.float32)

        # Search in FAISS index
        try:
            logging.debug(f"Searching FAISS index with question embedding: {question_embedding}")
            _, indices = gpu_index.search(question_embedding, k=1)
            if indices.size == 0:
                logging.error("No results found in FAISS search.")
                return "No relevant document found."
            logging.debug(f"Indices found: {indices}")
        except Exception as e:
            logging.error(f"Error during FAISS search: {e}")
            return f"An error occurred during search: {e}"

        # Retrieve the most relevant document
        try:
            relevant_doc = data.iloc[indices[0][0]]
            logging.debug(f"Relevant document: {relevant_doc}")
        except Exception as e:
            logging.error(f"Error retrieving document: {e}")
            return "An error occurred while retrieving the document. Please try again."
        # Use the QA model to extract the answer from the document's abstract
        try:
            context = relevant_doc['Abstract']
            response = qa_model(question=question, context=context)
            logging.debug(f"Response: {response}")
            return response['answer']
        except Exception as e:
            logging.error(f"Error generating answer: {e}")
            return "An error occurred while generating the answer. Please try again."
    except Exception as e:
        logging.error(f"Error during retrieval and generation: {e}")
        return "An error occurred. Please try again."

# Create the Gradio interface
interface = gr.Interface(
    fn=retrieve_and_generate,
    inputs=gr.Textbox(lines=2, placeholder="Ask a question about the documents..."),
    outputs="text",
    title="RAG Chatbot",
    description="Ask questions about the documents in the CSV file.",
)

# Launch the Gradio app
interface.launch()
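
# --- Optional smoke test (a minimal sketch, not part of the original app) ---
# Assumes RBDx10kstats.csv has an 'Abstract' text column and an 'embedding'
# column holding a JSON-encoded list of floats per row, as the code above
# expects. Uncomment (and comment out interface.launch() above, which blocks)
# to exercise the retrieval pipeline once without the Gradio UI:
#
# if __name__ == "__main__":
#     print(retrieve_and_generate("What is the first document about?"))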