import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
import gradio as gr
import json
from transformers import AutoTokenizer, AutoModelForCausalLM

# Use the GPU if one is available, otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the CSV file with embeddings
df = pd.read_csv('RBDx10kstats.csv')
df['embedding'] = df['embedding'].apply(json.loads)  # Convert JSON string back to list

# Convert embeddings to tensor for efficient retrieval
embeddings = torch.tensor(df['embedding'].tolist(), device=device)
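# The 'embedding' column is assumed to have been precomputed offline with the same
# Sentence Transformer loaded below; a hedged sketch of that step (not part of this file):
#
#   enc = SentenceTransformer('all-MiniLM-L6-v2')
#   df['embedding'] = df['Abstract'].apply(
#       lambda text: json.dumps(enc.encode(text).tolist()))
#   df.to_csv('RBDx10kstats.csv', index=False)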

# Load the Sentence Transformer model
model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

# Load GPT-2 as the causal language model for response generation
gpt2_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
gpt2_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(device)
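# A LLaMA-family checkpoint can be swapped in through the same two calls
# (gated Hugging Face repo; the id below is only an example):
#   AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
#   AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to(device)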

# Define the function to find the most relevant document
def retrieve_relevant_doc(query):
    query_embedding = model.encode(query, convert_to_tensor=True, device=device)
    similarities = util.pytorch_cos_sim(query_embedding, embeddings)[0]
    best_match_idx = torch.argmax(similarities).item()
    return df.iloc[best_match_idx]['Abstract']
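# Example usage (the query string is only illustrative):
#   retrieve_relevant_doc("What are common symptoms of RBD?")
#   returns the 'Abstract' whose embedding has the highest cosine similarity to the query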

# Define the function to generate a response
def generate_response(query):
    relevant_doc = retrieve_relevant_doc(query)
    input_text = f"Document: {relevant_doc}\n\nQuestion: {query}\n\nAnswer:"
    # Truncate the prompt so it leaves room for the answer in GPT-2's 1024-token context
    inputs = gpt2_tokenizer(input_text, return_tensors="pt", truncation=True, max_length=874).to(device)
    outputs = gpt2_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=150,  # max_length would count the prompt tokens as well
        pad_token_id=gpt2_tokenizer.eos_token_id,  # GPT-2 has no dedicated pad token
    )
    # Decode only the newly generated tokens so the prompt is not echoed back
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return gpt2_tokenizer.decode(new_tokens, skip_special_tokens=True)
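# Example usage (illustrative query; the generated text will vary by model and decoding):
#   generate_response("Summarize the key findings on RBD prevalence.")
#   builds a "Document ... Question ... Answer:" prompt and returns only the new tokens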

# Create a Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs="text",
    title="RAG Chatbot",
    description="This chatbot retrieves relevant documents based on your query and generates responses using LLaMA."
)

# Launch the Gradio interface
iface.launch()