# Scraped from a Hugging Face Space file view (Space status: "Runtime error").
# File size: 2,792 bytes.
# Commit hashes from the page's blame gutter: 8b582d7 921072e 6c4a2c9 8f37923 8e2a2a1 259ef44 921072e 75e6466 6c4a2c9 259ef44 6c4a2c9 75e6466 6c4a2c9 8f37923 75e6466 6c4a2c9 921072e 6c4a2c9 921072e 8e2a2a1 921072e 6c4a2c9 921072e 6c4a2c9 921072e d186e31 921072e 5cb1a88 921072e 8f37923
import gradio as gr
import pandas as pd
import numpy as np
from transformers import pipeline, BertTokenizer, BertModel
import faiss
import torch
import json
import spaces
# Read the document table; its 'embedding' and 'Abstract' columns feed the
# retrieval index and the QA context further down.
data = pd.read_csv(filepath_or_buffer="RBD10kstats.csv")
# Function to safely convert JSON strings to numpy arrays
def safe_json_loads(x):
    """Parse one JSON-encoded embedding cell into a numpy array.

    Args:
        x: A cell from the CSV 'embedding' column, normally a JSON string
            such as "[0.1, 0.2, ...]". Empty CSV cells are read by pandas as
            NaN (a float), not a string.

    Returns:
        ``np.array`` of the decoded value, or an empty array when *x* is not
        valid JSON or is not a str/bytes/bytearray at all.
    """
    try:
        return np.array(json.loads(x))
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError covers non-string cells (e.g. NaN for a missing value);
        # without it one blank cell aborts the whole startup.
        print(f"Error decoding JSON: {e}")
        return np.array([])  # empty array marks the row as unusable
# Apply the safe_json_loads function to the embedding column
data['embedding'] = data['embedding'].apply(safe_json_loads)
# Drop rows whose embedding could not be parsed (empty array), then renumber
# the index 0..n-1. Without the reset, label lookups such as
# data['embedding'][0] raise KeyError whenever row 0 was dropped. The reset
# also keeps labels aligned with the positional ids FAISS returns.
data = data[data['embedding'].apply(lambda x: x.size > 0)].reset_index(drop=True)
# Initialize FAISS index (exact L2 search over every document embedding).
if data.empty:
    raise RuntimeError("No valid embeddings found in RBD10kstats.csv.")
# FAISS works on contiguous float32 matrices. json-decoded arrays are float64,
# so convert explicitly. np.stack raises ValueError if row lengths differ.
embedding_matrix = np.ascontiguousarray(
    np.stack(data['embedding'].values), dtype=np.float32
)
dimension = embedding_matrix.shape[1]
cpu_index = faiss.IndexFlatL2(dimension)
# GPU helpers such as StandardGpuResources exist only in GPU builds of faiss,
# so probe for them before touching any GPU API.
num_gpus = faiss.get_num_gpus() if hasattr(faiss, "StandardGpuResources") else 0
if num_gpus > 0:
    # `res` stays referenced at module level: the GPU resources object must
    # outlive the GPU index built from it.
    res = faiss.StandardGpuResources()  # use a single GPU
    gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)  # move to GPU
else:
    # No GPU visible at import time: search on the CPU instead of aborting,
    # matching the CPU fallback used for the torch models. The name
    # `gpu_index` is kept because the retrieval function refers to it.
    res = None
    gpu_index = cpu_index
gpu_index.add(embedding_matrix)
# Select where the BERT encoder runs: CUDA when torch can see a GPU, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Extractive QA pipeline. `pipeline` is given an integer device index here:
# 0 for the first CUDA device, -1 for CPU.
qa_model = pipeline(
    "question-answering",
    model="distilbert-base-uncased-distilled-squad",
    device=0 if device.type == "cuda" else -1,
)

# BERT encoder and tokenizer used to embed incoming questions.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model = model.to(device)
# Function to embed the question using BERT
def embed_question(question, model, tokenizer):
    """Embed *question* as the mean of the model's final hidden states.

    Args:
        question: The question text.
        model: A torch encoder whose output exposes ``last_hidden_state``
            (here ``BertModel``).
        tokenizer: The tokenizer that matches *model*.

    Returns:
        A numpy array with one row per input (a single row for one question
        string). Each row is ``last_hidden_state`` averaged over token
        positions.
    """
    # Send inputs to the device holding the model's weights, so any model
    # passed in works, not only one placed on a module-level device.
    model_device = next(model.parameters()).device
    # truncation=True caps the input at the tokenizer's maximum length.
    # Longer questions would otherwise exceed BERT's position limit and raise.
    inputs = tokenizer(question, return_tensors='pt', truncation=True).to(model_device)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
# Function to retrieve the relevant document and generate a response
@spaces.GPU(duration=120)
def retrieve_and_generate(question):
    """Answer *question* from the single most similar document.

    Embeds the question with BERT and finds the nearest stored embedding in
    the FAISS index. The extractive QA model then runs over that row's
    'Abstract'.

    Args:
        question: Free-text question from the Gradio textbox.

    Returns:
        The answer span from the QA model. Returns a short explanatory
        message instead when the question is blank, no document matches,
        or the matching abstract is empty.
    """
    # Blank input has nothing to embed or answer; short-circuit it.
    if not question or not question.strip():
        return "Please enter a question."
    # Embed the question
    question_embedding = embed_question(question, model, tokenizer)
    # Nearest neighbour (k=1) by L2 distance. FAISS reports id -1 when it has
    # no result, and data.iloc[-1] would silently return the last row.
    _, indices = gpu_index.search(question_embedding, k=1)
    doc_position = int(indices[0][0])
    if doc_position < 0:
        return "No matching document found."
    relevant_doc = data.iloc[doc_position]
    context = relevant_doc['Abstract']
    # pandas reads a missing abstract as NaN (a float), not a string.
    if not isinstance(context, str) or not context.strip():
        return "The matching document has no abstract to answer from."
    # Use the QA model to extract the answer span from the abstract.
    response = qa_model(question=question, context=context)
    return response['answer']
# Text-in / text-out web UI around the retrieval + QA function, started
# immediately when this script runs.
interface = gr.Interface(
    retrieve_and_generate,
    gr.Textbox(lines=2, placeholder="Ask a question about the documents..."),
    "text",
    title="RAG Chatbot",
    description="Ask questions about the documents in the CSV file.",
)
interface.launch()