"""Gradio RAG chatbot over the abstracts stored in RB10kstats.csv.

Each CSV row carries a precomputed embedding (column ``embeddings``) and an
``Abstract``. A question is embedded with BERT (mean-pooled last hidden state),
the nearest row is found with an exact L2 FAISS search, and an extractive QA
model pulls the answer span out of that row's abstract.
"""
import gradio as gr
import pandas as pd
import numpy as np
from transformers import pipeline, BertTokenizer, BertModel
import faiss
import torch
import spaces


def _parse_embedding(text):
    """Parse a serialized vector such as ``"[0.1, 0.2, ...]"`` into a float32 array.

    Accepts comma-separated (list/JSON style) as well as whitespace-separated
    (numpy ``str()`` style) values, with or without surrounding brackets.
    """
    cleaned = str(text).strip().strip("[]").replace(",", " ")
    return np.array(cleaned.split(), dtype=np.float32)


# Load CSV data
data = pd.read_csv('RB10kstats.csv')
if data.empty:
    raise ValueError("RB10kstats.csv contains no rows to index")

# Convert embedding column from its string form to numpy float32 vectors
data['embeddings'] = data['embeddings'].apply(_parse_embedding)

# FAISS expects a contiguous float32 (n, d) matrix; np.stack also fails loudly
# if rows have inconsistent dimensions.
embedding_matrix = np.ascontiguousarray(np.stack(data['embeddings'].to_numpy()), dtype=np.float32)
dimension = embedding_matrix.shape[1]

# Initialize FAISS index. Use a GPU only when this FAISS build supports it and
# a GPU is actually visible; otherwise (faiss-cpu, or no GPU attached at import
# time) fall back to the CPU index instead of crashing at startup.
# The name `gpu_index` is kept for compatibility even when the index is on CPU.
index = faiss.IndexFlatL2(dimension)
if hasattr(faiss, "StandardGpuResources") and faiss.get_num_gpus() > 0:
    res = faiss.StandardGpuResources()  # use a single GPU; kept at module level so it outlives the index
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)  # move to GPU
else:
    gpu_index = index
gpu_index.add(embedding_matrix)

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load QA model
qa_model = pipeline(
    "question-answering",
    model="distilbert-base-uncased-distilled-squad",
    device=0 if torch.cuda.is_available() else -1,
)

# Load BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased').to(device)


@spaces.GPU(duration=120)
def embed_question(question, model, tokenizer):
    """Embed ``question`` as the mean of BERT's last hidden states.

    Args:
        question: The user's question text.
        model: A BERT model already moved to ``device``.
        tokenizer: The matching BERT tokenizer.

    Returns:
        A numpy array of shape (1, hidden_size) taken from the model output.
    """
    # truncation=True caps input at the tokenizer's model_max_length so that
    # very long questions don't overflow BERT's position embeddings.
    inputs = tokenizer(question, return_tensors='pt', truncation=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()


@spaces.GPU(duration=120)
def retrieve_and_generate(question):
    """Answer ``question`` from the most similar document's abstract.

    Returns the extracted answer span, or a short explanatory message when the
    question is blank, no document is found, or the document has no abstract.
    """
    if not question or not question.strip():
        return "Please enter a question."

    # Embed the question.
    # NOTE(review): assumes the CSV embeddings were produced with the same
    # model/pooling (same dimension) — FAISS rejects mismatched dimensions.
    question_embedding = embed_question(question, model, tokenizer)

    # Search in FAISS index
    _, indices = gpu_index.search(question_embedding, k=1)
    best = int(indices[0][0])
    if best < 0:
        # FAISS reports -1 when no neighbour exists; iloc[-1] would silently
        # pick the last row instead.
        return "No relevant document found."

    # Retrieve the most relevant document
    relevant_doc = data.iloc[best]

    # Use the QA model to extract the answer from the abstract
    context = relevant_doc['Abstract']
    if pd.isna(context) or not str(context).strip():
        return "The most relevant document has no abstract to answer from."
    response = qa_model(question=question, context=str(context))
    return response['answer']


# Create a Gradio interface (gr.inputs.* was removed in Gradio 4; use gr.Textbox)
interface = gr.Interface(
    fn=retrieve_and_generate,
    inputs=gr.Textbox(lines=2, placeholder="Ask a question about the documents..."),
    outputs="text",
    title="RAG Chatbot",
    description="Ask questions about the documents in the CSV file.",
)

# Launch the Gradio app
if __name__ == "__main__":
    interface.launch()