Spaces: Runtime error

Update app.py

app.py CHANGED
@@ -1,51 +1,84 @@
Before:

 from datasets import load_dataset
-from sentence_transformers import SentenceTransformer
-import faiss
 import numpy as np
 import gradio as gr
 import chromadb

-(removed: the old embedding-model and index setup, roughly old lines 8-33; its text is not recoverable from this page)

 print("Stored embeddings in ChromaDB!")

 # Search function using ChromaDB
-def search_wikipedia(query, top_k=3):
-    query_embedding = embed_model.encode([query]).tolist()
-    results = collection.query(
-        query_embeddings=query_embedding,
-        n_results=top_k
-    )
-    return "\n\n".join(results["documents"][0])  # Return top results

 # Gradio Interface
 iface = gr.Interface(
-    fn=search_wikipedia,
     inputs="text",
     outputs="text",
     title="Wikipedia Search RAG",

After:

 from datasets import load_dataset
 import numpy as np
 import gradio as gr
 import chromadb
+from transformers import AutoModel, AutoTokenizer, pipeline
+import torch

+# Initialize the ChromaDB client; PersistentClient keeps the index on disk under ./chroma_db
+chroma_client = chromadb.PersistentClient(path="./chroma_db")
+collection = chroma_client.get_or_create_collection(name="wikipedia_docs")

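Chroma collections default to L2 distance, while sentence-embedding models such as BGE are usually compared by cosine similarity. A minimal sketch of selecting that at creation time with Chroma's standard hnsw:space setting (using it here is a suggestion, not something this commit does):

    collection = chroma_client.get_or_create_collection(
        name="wikipedia_docs",
        metadata={"hnsw:space": "cosine"},  # suggested: cosine instead of the default L2
    )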
+# Load the BAAI embedding model
+model_name = "BAAI/bge-base-en"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModel.from_pretrained(model_name)

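As written, the encoder runs on CPU. If the Space has a GPU, a small variant moves it over (the device handling is an assumption on top of the commit, which relies on the CPU defaults):

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained(model_name).to(device)
    # get_embedding would then also need: inputs = {k: v.to(device) for k, v in inputs.items()}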
+def get_embedding(text):
+    """Generate an embedding for one text using BAAI/bge-base-en."""
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    # CLS-token embedding; [0] drops the batch axis so callers get a flat list,
+    # which collection.add(embeddings=[embedding]) and query() expect
+    return outputs.last_hidden_state[:, 0, :].numpy()[0].tolist()

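The BGE model card additionally recommends L2-normalising the CLS embedding and prefixing retrieval queries with an instruction string; both affect ranking quality. A sketch of that variant (the helper name and the is_query flag are illustrative, not from the commit):

    import torch.nn.functional as F

    def get_embedding_bge(text, is_query=False):
        if is_query:
            text = "Represent this sentence for searching relevant passages: " + text
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        cls = outputs.last_hidden_state[:, 0, :]         # CLS pooling, as above
        return F.normalize(cls, p=2, dim=1)[0].tolist()  # unit-length vector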
+# Load the LLaMA model (Meta LLaMA 2)
+llama_pipe = pipeline("text-generation", model="meta-llama/Llama-2-7b-chat-hf")

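This line is the most likely source of the Space's "Runtime error" banner: meta-llama/Llama-2-7b-chat-hf is a gated checkpoint, so loading fails without an approved access token, and the 7B weights need roughly 14 GB in fp16, more than basic Space hardware provides. A sketch of loading with an explicit token, as accepted by recent transformers versions (HF_TOKEN is an assumed Space secret; device_map="auto" also requires the accelerate package):

    import os

    llama_pipe = pipeline(
        "text-generation",
        model="meta-llama/Llama-2-7b-chat-hf",
        token=os.environ.get("HF_TOKEN"),  # gated repo: needs approved access
        device_map="auto",                 # place weights on GPU when present
    )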
+# Load a small subset (1,000 rows); disabled for now in favour of two sample docs
+#dataset = load_dataset("wiki40b", "en", split="train[:1000]")
+
+# Extract only the text field
+#docs = [d["text"] for d in dataset]
+docs = ["Machine learning is a field of AI...", "Neural networks are inspired by the brain..."]
+
+#print("Loaded dataset with", len(docs), "documents.")
+
+# ✅ Step 2: embed each document and store it in ChromaDB
+for i, doc in enumerate(docs):
+    embedding = get_embedding(doc)
+    collection.add(ids=[str(i)], embeddings=[embedding], documents=[doc])

 print("Stored embeddings in ChromaDB!")
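Each loop iteration is a separate ChromaDB write. Chroma's add() also accepts whole batches, so an equivalent single call is possible (same data, just batched):

    collection.add(
        ids=[str(i) for i in range(len(docs))],
        embeddings=[get_embedding(d) for d in docs],
        documents=docs,
    )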

+# Previous approach: store precomputed embeddings in ChromaDB
+#for i, (doc, embedding) in enumerate(zip(docs, embeddings)):
+#    collection.add(
+#        ids=[str(i)],                     # unique ID for each doc
+#        embeddings=[embedding.tolist()],  # convert numpy array to list
+#        documents=[doc]
+#    )
+
+
 # Search function using ChromaDB
+#def search_wikipedia(query, top_k=3):
+#    query_embedding = embed_model.encode([query]).tolist()
+#    results = collection.query(
+#        query_embeddings=query_embedding,
+#        n_results=top_k
+#    )
+#    return "\n\n".join(results["documents"][0])  # return top results
+#    return results["documents"][0]               # return top results
+
+# Search ChromaDB and generate a response with LLaMA
+def query_llama(user_input):
+    query_embedding = get_embedding(user_input)
+    results = collection.query(query_embeddings=[query_embedding], n_results=3)
+
+    # query() returns one list per query, so the inner list must be checked too
+    if not results["documents"] or not results["documents"][0]:
+        return "No relevant documents found."
+
+    context = " ".join(results["documents"][0])
+    prompt = f"Using this context, answer the question: {user_input}\nContext: {context}"
+
+    response = llama_pipe(prompt, max_length=200)
+    return f"**LLaMA Response:** {response[0]['generated_text']}\n\n**Retrieved Docs:** {context}"

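One generation pitfall here: max_length counts prompt tokens plus output tokens, so a long retrieved context can exhaust the budget before the model writes anything. Bounding only the reply is the usual fix (both kwargs are standard text-generation pipeline options):

    response = llama_pipe(prompt, max_new_tokens=200, return_full_text=False)
    answer = response[0]["generated_text"]  # reply only, prompt stripped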
 # Gradio Interface
 iface = gr.Interface(
+    fn=query_llama,
     inputs="text",
     outputs="text",
     title="Wikipedia Search RAG",
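The hunk stops at new line 84, mid-call; the remaining arguments and the launch live outside the shown diff. For orientation, a hypothetical completion of the interface as a Gradio Space would typically run it (the closing lines are an assumption, not part of the commit):

    iface = gr.Interface(
        fn=query_llama,
        inputs="text",
        outputs="text",
        title="Wikipedia Search RAG",
    )
    iface.launch()  # Spaces execute app.py top to bottom and serve the app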